From e3a3971a22aabf553256f635c472e36d7074d427 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Tue, 26 May 2020 12:51:30 +0100 Subject: [PATCH 01/88] Added missing operators 16x8 PAD operators needed for keras Mobilenet V1/V2. --- .../lite/kernels/internal/reference/pad.h | 10 +++ tensorflow/lite/kernels/pad.cc | 63 ++++++++++++------- tensorflow/lite/kernels/pad_test.cc | 54 ++++++++++++---- tensorflow/lite/kernels/register.cc | 4 +- tensorflow/lite/toco/tflite/op_version.cc | 2 + .../lite/tools/versioning/op_version.cc | 4 +- 6 files changed, 97 insertions(+), 40 deletions(-) diff --git a/tensorflow/lite/kernels/internal/reference/pad.h b/tensorflow/lite/kernels/internal/reference/pad.h index e20aa5e4b2a..19de2548921 100644 --- a/tensorflow/lite/kernels/internal/reference/pad.h +++ b/tensorflow/lite/kernels/internal/reference/pad.h @@ -168,6 +168,16 @@ inline void PadImageStyle(const tflite::PadParams& op_params, output_data); } +template +inline void PadImageStyle(const tflite::PadParams& op_params, + const RuntimeShape& input_shape, + const int16_t* input_data, const P* pad_value_ptr, + const RuntimeShape& output_shape, + int16_t* output_data) { + Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape, + output_data); +} + template inline void PadImageStyle(const tflite::PadParams& op_params, const RuntimeShape& input_shape, diff --git a/tensorflow/lite/kernels/pad.cc b/tensorflow/lite/kernels/pad.cc index 1bd4f65a043..4a4668c9634 100644 --- a/tensorflow/lite/kernels/pad.cc +++ b/tensorflow/lite/kernels/pad.cc @@ -119,6 +119,42 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { return ResizeOutputTensor(context, &op_context); } +template +TfLiteStatus EvalSignedInt(TfLiteContext* context, const PadContext& op_context, + const tflite::PadParams& op_params) { + integer_type pad_value; + if (op_context.constant_values == nullptr) { + // Quantized Pad requires that 0 is represented in the quantized + // range. + TF_LITE_ENSURE(context, op_context.output->params.zero_point >= + std::numeric_limits::min()); + TF_LITE_ENSURE(context, op_context.output->params.zero_point <= + std::numeric_limits::max()); + pad_value = static_cast(op_context.output->params.zero_point); + } else { + // Quantized Pad requires that 'constant_values' is represented in the + // same quantized range as the input and output tensors. 
+ TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, + op_context.constant_values->params.zero_point); + TF_LITE_ENSURE_EQ(context, op_context.output->params.scale, + op_context.constant_values->params.scale); + pad_value = *GetTensorData(op_context.constant_values); + } + const integer_type pad_value_copy = pad_value; + if (op_context.resizing_category == ResizingCategory::kImageStyle) { + reference_ops::PadImageStyle( + op_params, GetTensorShape(op_context.input), + GetTensorData(op_context.input), &pad_value_copy, + GetTensorShape(op_context.output), + GetTensorData(op_context.output)); + } else { + optimized_ops::Pad(op_params, GetTensorShape(op_context.input), + GetTensorData(op_context.input), + &pad_value_copy, GetTensorShape(op_context.output), + GetTensorData(op_context.output)); + } +} + template TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { PadContext op_context(context, node); @@ -208,29 +244,10 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } break; case kTfLiteInt8: { - int8_t pad_value; - if (op_context.constant_values == nullptr) { - // Quantized Pad requires that 0 is represented in the quantized - // range. - TF_LITE_ENSURE(context, op_context.output->params.zero_point >= - std::numeric_limits::min()); - TF_LITE_ENSURE(context, op_context.output->params.zero_point <= - std::numeric_limits::max()); - pad_value = static_cast(op_context.output->params.zero_point); - } else { - // Quantized Pad requires that 'constant_values' is represented in the - // same quantized range as the input and output tensors. - TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, - op_context.constant_values->params.zero_point); - TF_LITE_ENSURE_EQ(context, op_context.output->params.scale, - op_context.constant_values->params.scale); - pad_value = *GetTensorData(op_context.constant_values); - } - if (op_context.resizing_category == ResizingCategory::kImageStyle) { - TF_LITE_PAD(reference_ops, PadImageStyle, int8_t, pad_value); - } else { - TF_LITE_PAD(reference_ops, Pad, int8_t, pad_value); - } + EvalSignedInt(context, op_context, op_params); + } break; + case kTfLiteInt16: { + EvalSignedInt(context, op_context, op_params); } break; case kTfLiteInt32: { int32_t pad_value = diff --git a/tensorflow/lite/kernels/pad_test.cc b/tensorflow/lite/kernels/pad_test.cc index 8ef03290531..7a6ac06ec57 100644 --- a/tensorflow/lite/kernels/pad_test.cc +++ b/tensorflow/lite/kernels/pad_test.cc @@ -12,8 +12,8 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include #include +#include #include "tensorflow/lite/interpreter.h" #include "tensorflow/lite/kernels/register.h" #include "tensorflow/lite/kernels/test_util.h" @@ -332,20 +332,27 @@ TEST_F(QuantizedPadOpTest, UInt8ZeroNotInQuantizationRange) { TEST_F(QuantizedPadOpTest, Int8ZeroNotInQuantizationRange) { ZeroNotInQuantizationRange(); } +TEST_F(QuantizedPadOpTest, Int16ZeroNotInQuantizationRange) { + ZeroNotInQuantizationRange(); +} #endif template void SimpleConstTest() { // Padding is represented as four 2-D lists representing above padding and // below padding (i.e. {{0, 0}, {1, 1}, {1, 1}, {0, 0}}). 
- PadOpConstModel m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, - {0, 0, 1, 1, 1, 1, 0, 0}, {tensor_dtype, {}, -1.0, 1.0}); + + const float kMin = -1.f; + const float kMax = tensor_dtype == TensorType_INT16 ? 32767.f / 32768.f : 1.f; + + PadOpConstModel m({tensor_dtype, {1, 2, 2, 1}, kMin, kMax}, {4, 2}, + {0, 0, 1, 1, 1, 1, 0, 0}, {tensor_dtype, {}, kMin, kMax}); m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); m.Invoke(); EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, - -1.0, 1.0))); + kMin, kMax))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } @@ -355,18 +362,24 @@ TEST_F(QuantizedPadOpTest, UInt8SimpleConstTest) { TEST_F(QuantizedPadOpTest, Int8SimpleConstTest) { SimpleConstTest(); } +TEST_F(QuantizedPadOpTest, Int16SimpleConstTest) { + SimpleConstTest(); +} template void SimpleDynamicTest() { - PadOpDynamicModel m({tensor_dtype, {1, 2, 2, 1}, -1.0, 1.0}, {4, 2}, - {tensor_dtype, {}, -1.0, 1.0}); + const float kMin = -1.f; + const float kMax = tensor_dtype == TensorType_INT16 ? 32767.f / 32768.f : 1.f; + + PadOpDynamicModel m({tensor_dtype, {1, 2, 2, 1}, kMin, kMax}, {4, 2}, + {tensor_dtype, {}, kMin, kMax}); m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7}); m.SetPaddings({0, 0, 1, 1, 1, 1, 0, 0}); m.Invoke(); EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, 0, 0, 0, 0, -0.8, 0.2, 0, 0, 0.9, 0.7, 0, 0, 0, 0, 0}, - -1.0, 1.0))); + kMin, kMax))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 4, 1})); } @@ -376,18 +389,24 @@ TEST_F(QuantizedPadOpTest, UInt8SimpleDynamicTest) { TEST_F(QuantizedPadOpTest, Int8SimpleDynamicTest) { SimpleDynamicTest(); } +TEST_F(QuantizedPadOpTest, Int16SimpleDynamicTest) { + SimpleDynamicTest(); +} template void AdvancedConstTest() { - PadOpConstModel m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, - {0, 0, 0, 2, 1, 3, 0, 0}, {tensor_dtype, {}, -1.0, 1.0}); + const float kMin = -1.f; + const float kMax = tensor_dtype == TensorType_INT16 ? 32767.f / 32768.f : 1.f; + + PadOpConstModel m({tensor_dtype, {1, 2, 3, 1}, kMin, kMax}, {4, 2}, + {0, 0, 0, 2, 1, 3, 0, 0}, {tensor_dtype, {}, kMin, kMax}); m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); m.Invoke(); EXPECT_THAT(m.template GetDequantizedOutput(), ElementsAreArray(DequantizedArrayNear( {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - -1.0, 1.0))); + kMin, kMax))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } @@ -397,11 +416,17 @@ TEST_F(QuantizedPadOpTest, UInt8AdvancedConstTest) { TEST_F(QuantizedPadOpTest, Int8AdvancedConstTest) { AdvancedConstTest(); } +TEST_F(QuantizedPadOpTest, Int16AdvancedConstTest) { + AdvancedConstTest(); +} template void AdvancedDynamicTest() { - PadOpDynamicModel m({tensor_dtype, {1, 2, 3, 1}, -1.0, 1.0}, {4, 2}, - {tensor_dtype, {}, -1.0, 1.0}); + const float kMin = -1.f; + const float kMax = tensor_dtype == TensorType_INT16 ? 
32767.f / 32768.f : 1.f; + + PadOpDynamicModel m({tensor_dtype, {1, 2, 3, 1}, kMin, kMax}, {4, 2}, + {tensor_dtype, {}, kMin, kMax}); m.template SetQuantizedInput({-0.8, 0.2, 0.9, 0.7, 0.1, -0.3}); m.SetPaddings({0, 0, 0, 2, 1, 3, 0, 0}); m.Invoke(); @@ -409,7 +434,7 @@ void AdvancedDynamicTest() { ElementsAreArray(DequantizedArrayNear( {0, -0.8, 0.2, 0.9, 0, 0, 0, 0, 0.7, 0.1, -0.3, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, - -1.0, 1.0))); + kMin, kMax))); EXPECT_THAT(m.GetOutputShape(), ElementsAreArray({1, 4, 7, 1})); } @@ -419,6 +444,9 @@ TEST_F(QuantizedPadOpTest, UInt8AdvancedDynamicTest) { TEST_F(QuantizedPadOpTest, Int8AdvancedDynamicTest) { AdvancedDynamicTest(); } +TEST_F(QuantizedPadOpTest, Int16AdvancedDynamicTest) { + AdvancedDynamicTest(); +} #ifdef GTEST_HAS_DEATH_TEST TEST(PadV2OpTest, TooManyDimensions) { diff --git a/tensorflow/lite/kernels/register.cc b/tensorflow/lite/kernels/register.cc index 8ca58e6a309..7dc2ca169a6 100644 --- a/tensorflow/lite/kernels/register.cc +++ b/tensorflow/lite/kernels/register.cc @@ -111,9 +111,9 @@ BuiltinOpResolver::BuiltinOpResolver() { Register_UNIDIRECTIONAL_SEQUENCE_LSTM(), /* min_version = */ 1, /* max_version = */ 2); AddBuiltin(BuiltinOperator_PAD, Register_PAD(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_PADV2, Register_PADV2(), /* min_version = */ 1, - /* max_version = */ 2); + /* max_version = */ 3); AddBuiltin(BuiltinOperator_RESHAPE, Register_RESHAPE()); AddBuiltin(BuiltinOperator_RESIZE_BILINEAR, Register_RESIZE_BILINEAR(), /* min_version = */ 1, diff --git a/tensorflow/lite/toco/tflite/op_version.cc b/tensorflow/lite/toco/tflite/op_version.cc index cf127a9f459..00857891bfe 100644 --- a/tensorflow/lite/toco/tflite/op_version.cc +++ b/tensorflow/lite/toco/tflite/op_version.cc @@ -104,10 +104,12 @@ string GetMinimumRuntimeVersionForModel(const Model& model) { {{OperatorType::kMul, 4}, kPendingReleaseOpVersion}, {{OperatorType::kPad, 1}, "1.5.0"}, {{OperatorType::kPad, 2}, "1.14.0"}, + {{OperatorType::kPad, 3}, kPendingReleaseOpVersion}, {{OperatorType::kTile, 1}, "1.10.1"}, {{OperatorType::kTile, 2}, kPendingReleaseOpVersion}, {{OperatorType::kPadV2, 1}, "1.9.0"}, {{OperatorType::kPadV2, 2}, "1.14.0"}, + {{OperatorType::kPadV2, 3}, kPendingReleaseOpVersion}, {{OperatorType::kReshape, 1}, "1.5.0"}, {{OperatorType::kSoftmax, 1}, "1.5.0"}, {{OperatorType::kSoftmax, 2}, "1.14.0"}, diff --git a/tensorflow/lite/tools/versioning/op_version.cc b/tensorflow/lite/tools/versioning/op_version.cc index 118e2d420f8..adf6b1d247b 100644 --- a/tensorflow/lite/tools/versioning/op_version.cc +++ b/tensorflow/lite/tools/versioning/op_version.cc @@ -490,6 +490,8 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { case BuiltinOperator_CONCATENATION: case BuiltinOperator_SOFTMAX: + case BuiltinOperator_PAD: + case BuiltinOperator_PADV2: // In case of int16 inputs, the version is 3. if (op_sig.input_types.at(0) == TensorType_INT16) { return 3; @@ -500,8 +502,6 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) { return 1; case BuiltinOperator_ADD: - case BuiltinOperator_PAD: - case BuiltinOperator_PADV2: case BuiltinOperator_SPACE_TO_DEPTH: case BuiltinOperator_SPLIT_V: case BuiltinOperator_MEAN: From 8956cd169b4d6cc0a8d50dceb1f4da72b6515b39 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Wed, 27 May 2020 10:35:53 +0100 Subject: [PATCH 02/88] Excluded new tests for int16 from tests for acceleration. 
--- tensorflow/lite/delegates/nnapi/acceleration_test_list.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index 46a6a720d1e..de9ce3dc4b9 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -258,6 +258,8 @@ NegOpModel/.+,29 QuantizedPadOpTest/.+,29 QuantizedPadV2OpTest/.+,29 PadOpTest/.+,29 +# 16-bit tests are not supported +-QuantizedPadOpTest/Int16.+ # pooling_test FloatPoolingOpTest/L2PoolActivationRelu.*,29 From 829277a571dc2f93f82a73a3cf570db33bcf73ce Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Thu, 4 Jun 2020 10:44:35 +0100 Subject: [PATCH 03/88] Addressed reviewer's comments. Change-Id: Ibc97958cec422159f70ee2c4a8cbc723dc00f883 --- .../lite/kernels/internal/reference/pad.h | 32 ------------------- 1 file changed, 32 deletions(-) diff --git a/tensorflow/lite/kernels/internal/reference/pad.h b/tensorflow/lite/kernels/internal/reference/pad.h index 19de2548921..20fe3434ae5 100644 --- a/tensorflow/lite/kernels/internal/reference/pad.h +++ b/tensorflow/lite/kernels/internal/reference/pad.h @@ -137,43 +137,11 @@ inline void Pad(const tflite::PadParams& op_params, output_data); } -// One could make all PadImageStyle calls simply delegate the work to the -// ordinary Pad. However, it is better that the reference code asserts false in -// similar cases. template inline void PadImageStyle(const tflite::PadParams& op_params, const RuntimeShape& input_shape, const T* input_data, const P* pad_value_ptr, const RuntimeShape& output_shape, T* output_data) { - TFLITE_ASSERT_FALSE; -} - -template -inline void PadImageStyle(const tflite::PadParams& op_params, - const RuntimeShape& input_shape, - const uint8* input_data, const P* pad_value_ptr, - const RuntimeShape& output_shape, - uint8* output_data) { - Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape, - output_data); -} - -template -inline void PadImageStyle(const tflite::PadParams& op_params, - const RuntimeShape& input_shape, - const int8_t* input_data, const P* pad_value_ptr, - const RuntimeShape& output_shape, - int8_t* output_data) { - Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape, - output_data); -} - -template -inline void PadImageStyle(const tflite::PadParams& op_params, - const RuntimeShape& input_shape, - const int16_t* input_data, const P* pad_value_ptr, - const RuntimeShape& output_shape, - int16_t* output_data) { Pad(op_params, input_shape, input_data, pad_value_ptr, output_shape, output_data); } From 165c8c5dbd8ac183bd322266798321adb4886f45 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Thu, 4 Jun 2020 11:05:48 +0100 Subject: [PATCH 04/88] Addressed reviewer's comments. 
Change-Id: If8022418adcc6b6a93354625476f32155dd53d36 --- tensorflow/lite/kernels/pad.cc | 40 +++++----------------------------- 1 file changed, 5 insertions(+), 35 deletions(-) diff --git a/tensorflow/lite/kernels/pad.cc b/tensorflow/lite/kernels/pad.cc index 4a4668c9634..4c4caeea853 100644 --- a/tensorflow/lite/kernels/pad.cc +++ b/tensorflow/lite/kernels/pad.cc @@ -120,8 +120,8 @@ TfLiteStatus Prepare(TfLiteContext* context, TfLiteNode* node) { } template -TfLiteStatus EvalSignedInt(TfLiteContext* context, const PadContext& op_context, - const tflite::PadParams& op_params) { +TfLiteStatus EvalInt(TfLiteContext* context, const PadContext& op_context, + const tflite::PadParams& op_params) { integer_type pad_value; if (op_context.constant_values == nullptr) { // Quantized Pad requires that 0 is represented in the quantized @@ -211,43 +211,13 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) { } } break; case kTfLiteUInt8: { - uint8_t pad_value; - if (op_context.constant_values == nullptr) { - // Quantized Pad requires that 0 is represented in the quantized - // range. - TF_LITE_ENSURE(context, op_context.output->params.zero_point >= - std::numeric_limits::min()); - TF_LITE_ENSURE(context, op_context.output->params.zero_point <= - std::numeric_limits::max()); - pad_value = static_cast(op_context.output->params.zero_point); - } else { - // Quantized Pad requires that 'constant_values' is represented in the - // same quantized range as the input and output tensors. - TF_LITE_ENSURE_EQ(context, op_context.output->params.zero_point, - op_context.constant_values->params.zero_point); - TF_LITE_ENSURE_EQ(context, op_context.output->params.scale, - op_context.constant_values->params.scale); - pad_value = *GetTensorData(op_context.constant_values); - } - if (kernel_type == kReference) { - if (op_context.resizing_category == ResizingCategory::kImageStyle) { - TF_LITE_PAD(reference_ops, PadImageStyle, uint8_t, pad_value); - } else { - TF_LITE_PAD(reference_ops, Pad, uint8_t, pad_value); - } - } else if (kernel_type == kGenericOptimized) { - if (op_context.resizing_category == ResizingCategory::kImageStyle) { - TF_LITE_PAD(optimized_ops, PadImageStyle, uint8_t, pad_value); - } else { - TF_LITE_PAD(optimized_ops, Pad, uint8_t, pad_value); - } - } + EvalInt(context, op_context, op_params); } break; case kTfLiteInt8: { - EvalSignedInt(context, op_context, op_params); + EvalInt(context, op_context, op_params); } break; case kTfLiteInt16: { - EvalSignedInt(context, op_context, op_params); + EvalInt(context, op_context, op_params); } break; case kTfLiteInt32: { int32_t pad_value = From 1267401ad07c9f0296490a980e9d5297862323b0 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Wed, 10 Jun 2020 12:54:09 +0100 Subject: [PATCH 05/88] Fix for CI error. 
Change-Id: If577ceb08baaaa8aad3554aa54c1ffddf3f26825 --- tensorflow/lite/kernels/pad.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/lite/kernels/pad.cc b/tensorflow/lite/kernels/pad.cc index 4c4caeea853..b51c5260938 100644 --- a/tensorflow/lite/kernels/pad.cc +++ b/tensorflow/lite/kernels/pad.cc @@ -153,6 +153,8 @@ TfLiteStatus EvalInt(TfLiteContext* context, const PadContext& op_context, &pad_value_copy, GetTensorShape(op_context.output), GetTensorData(op_context.output)); } + + return kTfLiteOk; } template From 9616883b82ee81b4b03faafb219d5aa23c275ca8 Mon Sep 17 00:00:00 2001 From: Yong Tang Date: Thu, 18 Jun 2020 16:17:52 +0000 Subject: [PATCH 06/88] Add complex64 and complex128 gpu support for tensor_scatter_nd_add This PR adds complex64 and complex128 gpu support for tensor_scatter_nd_add, as was raised in 40577. This PR fixes 40577. Signed-off-by: Yong Tang --- tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc index bfffadecdf1..8bee7dbee67 100644 --- a/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc +++ b/tensorflow/core/kernels/inplace_ops_functor_gpu.cu.cc @@ -197,6 +197,8 @@ Status DoCopy(const Device& d, const Tensor& x, Tensor* y) { CASE(float) CASE(double) CASE(Eigen::half) + CASE(complex64) + CASE(complex128) CASE(int64) #undef CASE default: From e5e33d9995804e4cf2e071d86442854bf38e27ff Mon Sep 17 00:00:00 2001 From: Vishal Subbiah Date: Mon, 22 Jun 2020 12:01:48 -0700 Subject: [PATCH 07/88] supporting x dtype instead of only float 32 --- tensorflow/python/keras/backend.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 9330425272f..ee32ac41dd3 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -3452,7 +3452,7 @@ _VALUE_SET_CODE_STRING = """ >>> print(K.get_value(v)) 3.0 - Variable semantics in TensorFlow 2 are eager execution friendly. The above + Variable semantics in TensorFlow 2 are eager execution friendly. The above code is roughly equivalent to: >>> v = tf.Variable(1.) @@ -4555,7 +4555,7 @@ def relu(x, alpha=0., max_value=None, threshold=0): if threshold != 0: # computes x for x > threshold else 0 - x = x * math_ops.cast(math_ops.greater(x, threshold), floatx()) + x = x * math_ops.cast(math_ops.greater(x, threshold), x.dtype) elif max_value == 6: # if no threshold, then can use nn.relu6 native TF op for performance x = nn.relu6(x) From 8010ff21a88abf5a455dc29836907e6f30657dfb Mon Sep 17 00:00:00 2001 From: Vishal Subbiah Date: Mon, 22 Jun 2020 23:15:31 -0700 Subject: [PATCH 08/88] handling lists and tuples --- tensorflow/python/keras/backend.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index ee32ac41dd3..b36bc0839b8 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4541,7 +4541,10 @@ def relu(x, alpha=0., max_value=None, threshold=0): Returns: A tensor. 
""" - + if isinstance(x, tf.Tensor) or isinstance(x, np.ndarray): + dtype = x.dtype + else: + dtype = floatx() if alpha != 0.: if max_value is None and threshold == 0: return nn.leaky_relu(x, alpha=alpha) @@ -4555,7 +4558,7 @@ def relu(x, alpha=0., max_value=None, threshold=0): if threshold != 0: # computes x for x > threshold else 0 - x = x * math_ops.cast(math_ops.greater(x, threshold), x.dtype) + x = x * math_ops.cast(math_ops.greater(x, threshold), dtype=dtype) elif max_value == 6: # if no threshold, then can use nn.relu6 native TF op for performance x = nn.relu6(x) From 719dbabcb27accb3c16375398a6e6bfa7d0433e6 Mon Sep 17 00:00:00 2001 From: Vishal Subbiah Date: Mon, 22 Jun 2020 23:16:41 -0700 Subject: [PATCH 09/88] new line --- tensorflow/python/keras/backend.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index b36bc0839b8..0c1b660855e 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4541,6 +4541,7 @@ def relu(x, alpha=0., max_value=None, threshold=0): Returns: A tensor. """ + if isinstance(x, tf.Tensor) or isinstance(x, np.ndarray): dtype = x.dtype else: From 1fd1af01a95884fcf1c74d36c2890f7e3be7c419 Mon Sep 17 00:00:00 2001 From: Elena Zhelezina Date: Tue, 30 Jun 2020 12:24:40 +0100 Subject: [PATCH 10/88] Addressed reviewers comments. Change-Id: I542064cffe6079a2f77c316812fa691042c2f9d6 --- tensorflow/lite/tools/versioning/runtime_version.cc | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/lite/tools/versioning/runtime_version.cc b/tensorflow/lite/tools/versioning/runtime_version.cc index 92a7001606f..5b5b3dd4985 100644 --- a/tensorflow/lite/tools/versioning/runtime_version.cc +++ b/tensorflow/lite/tools/versioning/runtime_version.cc @@ -133,10 +133,12 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code, {{BuiltinOperator_NON_MAX_SUPPRESSION_V5, 1}, "2.1.0"}, {{BuiltinOperator_PAD, 1}, "1.5.0"}, {{BuiltinOperator_PAD, 2}, "1.14.0"}, + {{BuiltinOperator_PAD, 3}, kPendingReleaseVersion}, {{BuiltinOperator_TILE, 1}, "1.10.1"}, {{BuiltinOperator_TILE, 2}, "2.2.0"}, {{BuiltinOperator_PADV2, 1}, "1.9.0"}, {{BuiltinOperator_PADV2, 2}, "1.14.0"}, + {{BuiltinOperator_PADV2, 3}, kPendingReleaseVersion}, {{BuiltinOperator_RESHAPE, 1}, "1.5.0"}, {{BuiltinOperator_SOFTMAX, 1}, "1.5.0"}, {{BuiltinOperator_SOFTMAX, 2}, "1.14.0"}, From e944e388ad52cc99bfdaa48d2b760c30cc9e22d2 Mon Sep 17 00:00:00 2001 From: Vishal Subbiah Date: Tue, 30 Jun 2020 11:19:45 -0700 Subject: [PATCH 11/88] Update backend.py --- tensorflow/python/keras/backend.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 0c1b660855e..d4195338a78 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4541,8 +4541,10 @@ def relu(x, alpha=0., max_value=None, threshold=0): Returns: A tensor. 
""" - - if isinstance(x, tf.Tensor) or isinstance(x, np.ndarray): + if isinstance(x, (ops.Tensor, + variables_module.Variable, + sparse_tensor.SparseTensor, + np.ndarray)): dtype = x.dtype else: dtype = floatx() From 31b48f4d235909dae621cee4c1fc98898d99f0a7 Mon Sep 17 00:00:00 2001 From: Vishal Subbiah Date: Mon, 6 Jul 2020 10:35:29 -0700 Subject: [PATCH 12/88] addressing feedback --- tensorflow/python/keras/backend.py | 5 +---- 1 file changed, 1 insertion(+), 4 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index d4195338a78..6c3e70a02af 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4541,10 +4541,7 @@ def relu(x, alpha=0., max_value=None, threshold=0): Returns: A tensor. """ - if isinstance(x, (ops.Tensor, - variables_module.Variable, - sparse_tensor.SparseTensor, - np.ndarray)): + if hasattr(x, "dtype"): dtype = x.dtype else: dtype = floatx() From a9f9c45eb514fd7a40a8e0eb11085749130e5670 Mon Sep 17 00:00:00 2001 From: Vishal Subbiah Date: Mon, 6 Jul 2020 13:35:16 -0700 Subject: [PATCH 13/88] adding comment. --- tensorflow/python/keras/backend.py | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index 6c3e70a02af..121725fb416 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4541,10 +4541,10 @@ def relu(x, alpha=0., max_value=None, threshold=0): Returns: A tensor. """ - if hasattr(x, "dtype"): - dtype = x.dtype - else: - dtype = floatx() + # While x can be a tensor or variable, we also see cases where + # numpy arrays, lists, tuples are passed as well. + # lists, tuples do not have 'dtype' attribute. + dtype = getattr(x, 'dtype', floatx()) if alpha != 0.: if max_value is None and threshold == 0: return nn.leaky_relu(x, alpha=alpha) From 068488af7ee26dbfca9d84c30276429d86a53da6 Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Mon, 6 Jul 2020 21:14:27 -0400 Subject: [PATCH 14/88] Simplify code for keras_cpu_benchmark_test --- tensorflow/python/keras/benchmarks/BUILD | 2 + .../benchmarks/keras_cpu_benchmark_test.py | 148 ++++++++---------- 2 files changed, 66 insertions(+), 84 deletions(-) diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD index 9e0ae9194a5..dbdf7de87a8 100755 --- a/tensorflow/python/keras/benchmarks/BUILD +++ b/tensorflow/python/keras/benchmarks/BUILD @@ -29,6 +29,8 @@ py_test( srcs = ["keras_cpu_benchmark_test.py"], python_version = "PY3", deps = [ + ":benchmark_util", + "//tensorflow/python/keras", "//tensorflow:tensorflow_py", "//third_party/py/numpy", ], diff --git a/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py b/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py index b214df04746..63eedfcd11c 100644 --- a/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py +++ b/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py @@ -13,13 +13,10 @@ # limitations under the License. 
# ============================================================================== """Benchmark tests for CPU performance of Keras models.""" - from __future__ import absolute_import from __future__ import division from __future__ import print_function -import timeit - import numpy as np import six @@ -27,90 +24,38 @@ import tensorflow as tf from tensorflow.python.platform import benchmark from tensorflow.python.platform import test +from tensorflow.python.keras.benchmarks import benchmark_util -_NUM_EPOCHS = 4 - -# Dataset for benchmark -_MLP_X = np.random.random((5000, 784)) -_MLP_Y = np.random.random((5000, 10)) - -_CONVNET_X = np.random.random((5000, 28, 28, 1)) -_CONVNET_Y = np.random.random((5000, 10)) - -_LSTM_X = np.random.randint(0, 1999, size=(2500, 100)) -_LSTM_Y = np.random.random((2500, 1)) - - -class TimerCallback(tf.keras.callbacks.Callback): - - def __init__(self): - self.times = [] - self.timer = timeit.default_timer - self.startup_time = timeit.default_timer() - self.recorded_startup = False - - def on_epoch_begin(self, e, logs): - self.epoch_start_time = self.timer() - - def on_batch_end(self, e, logs): - if not self.recorded_startup: - self.startup_time = self.timer() - self.startup_time - self.recorded_startup = True - - def on_epoch_end(self, e, logs): - self.times.append(self.timer() - self.epoch_start_time) +# Loss function and optimizer. +_LOSS = 'binary_crossentropy' +_OPTIMIZER = 'rmsprop' class KerasModelCPUBenchmark( six.with_metaclass(benchmark.ParameterizedBenchmark, test.Benchmark)): - - # Set parameters for paramerized benchmark. + """Required Arguments for measure_performance: + x: Input data, it could be Numpy or load from tfds. + y: Target data. If `x` is a dataset, generator instance, + `y` should not be specified. + loss: Loss function for model. + optimizer: Optimizer for model. + Other details can see in `measure_performance()` method of + benchmark_util. + """ + """The parameters of each benchmark is a tuple: + (benchmark_name_suffix, batch_size, run_iters). + benchmark_name_suffix: The suffix of the benchmark test name with + convention `{bs}_{batch_size}`. + batch_size: Integer. Number of samples per gradient update. + run_iters: Integer. Number of iterations to run the + performance measurement. 
+ """ _benchmark_parameters = [ ('bs_32', 32, 3), ('bs_64', 64, 2), ('bs_128', 128, 2), ('bs_256', 256, 1), ('bs_512', 512, 1)] - def _measure_performance(self, model_fn, x, y, batch_size=32, - run_iters=4): - build_time_list, compile_time_list, startup_time_list = [], [], [] - avg_epoch_time_list, wall_time_list, exp_per_sec_list = [], [], [] - total_num_examples = y.shape[0] * _NUM_EPOCHS - - for _ in range(run_iters): - timer = timeit.default_timer - t0 = timer() - model = model_fn() - build_time = timer() - t0 - - t1 = timer() - model.compile('rmsprop', 'binary_crossentropy') - compile_time = timer() - t1 - - cbk = TimerCallback() - t2 = timer() - model.fit(x, y, epochs=_NUM_EPOCHS, batch_size=batch_size, - callbacks=[cbk], verbose=0) - end_time = timer() - - build_time_list.append(build_time) - compile_time_list.append(compile_time) - startup_time_list.append(cbk.startup_time) - avg_epoch_time_list.append(np.mean(cbk.times[1:])) - wall_time_list.append(end_time - t0) - exp_per_sec_list.append(total_num_examples / (end_time - t2)) - - results = {'build_time': np.mean(build_time_list), - 'compile_time': np.mean(compile_time_list), - 'startup_time': np.mean(startup_time_list), - 'avg_epoch_time': np.mean(avg_epoch_time_list), - 'wall_time': np.mean(wall_time_list), - 'exp_per_sec': np.mean(exp_per_sec_list)} - - self.report_benchmark( - iters=_NUM_EPOCHS, - wall_time=results['wall_time'], - extras=results) - def _mnist_mlp(self): + """Simple MLP model.""" model = tf.keras.Sequential() model.add(tf.keras.layers.Dense(512, activation='relu', input_shape=(784,))) model.add(tf.keras.layers.Dropout(0.2)) @@ -121,6 +66,7 @@ class KerasModelCPUBenchmark( return model def _mnist_convnet(self): + """Simple Convnet model.""" model = tf.keras.Sequential() model.add( tf.keras.layers.Conv2D( @@ -136,6 +82,7 @@ class KerasModelCPUBenchmark( return model def _imdb_lstm(self): + """Simple LSTM model.""" model = tf.keras.Sequential() model.add(tf.keras.layers.Embedding(20000, 128)) model.add(tf.keras.layers.LSTM(128, dropout=0.2, recurrent_dropout=0.2)) @@ -144,17 +91,50 @@ class KerasModelCPUBenchmark( return model def benchmark_mnist_mlp(self, batch_size, run_iters): - self._measure_performance(self._mnist_mlp, _MLP_X, _MLP_Y, - batch_size=batch_size, run_iters=run_iters) + """Benchmark for MLP model on synthetic mnist data.""" + mlp_x = np.random.random((5000, 784)) + mlp_y = np.random.random((5000, 10)) + results = benchmark_util.measure_performance(self._mnist_mlp, + x=mlp_x, + y=mlp_y, + batch_size=batch_size, + run_iters=run_iters, + optimizer=_OPTIMIZER, + loss=_LOSS) + self.report_benchmark(iters=run_iters, + wall_time=results['wall_time'], + extras=results) def benchmark_mnist_convnet(self, batch_size, run_iters): - self._measure_performance(self._mnist_convnet, _CONVNET_X, _CONVNET_Y, - batch_size=batch_size, run_iters=run_iters) + """Benchmark for Convnet model on synthetic mnist data.""" + convnet_x = np.random.random((5000, 28, 28, 1)) + convnet_y = np.random.random((5000, 10)) + results = benchmark_util.measure_performance(self._mnist_convnet, + x=convnet_x, + y=convnet_y, + batch_size=batch_size, + run_iters=run_iters, + optimizer=_OPTIMIZER, + loss=_LOSS) + self.report_benchmark(iters=run_iters, + wall_time=results['wall_time'], + extras=results) def benchmark_imdb_lstm(self, batch_size, run_iters): - self._measure_performance(self._imdb_lstm, _LSTM_X, _LSTM_Y, - batch_size=batch_size, run_iters=run_iters) + """Benchmark for LSTM model on synthetic imdb review dataset.""" + lstm_x = 
np.random.randint(0, 1999, size=(2500, 100)) + lstm_y = np.random.random((2500, 1)) + results = benchmark_util.measure_performance(self._imdb_lstm(), + x=lstm_x, + y=lstm_y, + batch_size=batch_size, + run_iters=run_iters, + optimizer=_OPTIMIZER, + loss=_LOSS) + self.report_benchmark(iters=run_iters, + wall_time=results['wall_time'], + extras=results) if __name__ == '__main__': - test.main() + test.main() \ No newline at end of file From 82de1484e50734dd7555a1ee39f7930923b0c69f Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Mon, 6 Jul 2020 22:12:02 -0400 Subject: [PATCH 15/88] Update BUILD using buildifier tool. --- tensorflow/python/keras/benchmarks/BUILD | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) mode change 100755 => 100644 tensorflow/python/keras/benchmarks/BUILD diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD old mode 100755 new mode 100644 index dbdf7de87a8..682d1b24be8 --- a/tensorflow/python/keras/benchmarks/BUILD +++ b/tensorflow/python/keras/benchmarks/BUILD @@ -30,8 +30,8 @@ py_test( python_version = "PY3", deps = [ ":benchmark_util", - "//tensorflow/python/keras", "//tensorflow:tensorflow_py", + "//tensorflow/python/keras", "//third_party/py/numpy", ], ) From e25f9aa53ffb722f6cfb76bb7f7ab431b7f1c773 Mon Sep 17 00:00:00 2001 From: Vo Van Nghia Date: Tue, 7 Jul 2020 21:48:51 +0700 Subject: [PATCH 16/88] Add ram file block cache --- .../experimental/filesystem/plugins/gcs/BUILD | 19 ++ .../filesystem/plugins/gcs/cleanup.h | 111 ++++++ .../filesystem/plugins/gcs/file_block_cache.h | 3 + .../plugins/gcs/ram_file_block_cache.cc | 317 ++++++++++++++++++ .../plugins/gcs/ram_file_block_cache.h | 266 +++++++++++++++ 5 files changed, 716 insertions(+) create mode 100644 tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h create mode 100644 tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc create mode 100644 tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD index 2a886dee4cb..514752a3b90 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/BUILD @@ -52,6 +52,25 @@ cc_library( ], ) +cc_library( + name = "cleanup", + hdrs = ["cleanup.h"], +) + +cc_library( + name = "ram_file_block_cache", + srcs = ["ram_file_block_cache.cc"], + hdrs = ["ram_file_block_cache.h"], + deps = [ + ":cleanup", + ":file_block_cache", + "//tensorflow/c:env", + "//tensorflow/c:tf_status", + "@com_google_absl//absl/base", + "@com_google_absl//absl/synchronization", + ], +) + tf_cc_test( name = "gcs_filesystem_test", srcs = [ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h b/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h new file mode 100644 index 00000000000..63edbee89b8 --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h @@ -0,0 +1,111 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +// MakeCleanup(f) returns an RAII cleanup object that calls 'f' in its +// destructor. The easiest way to use MakeCleanup is with a lambda argument, +// capturing the return value in an 'auto' local variable. Most users will not +// need more sophisticated syntax than that. +// +// Example: +// void func() { +// FILE* fp = fopen("data.txt", "r"); +// if (fp == nullptr) return; +// auto fp_cleaner = gtl::MakeCleanup([fp] { fclose(fp); }); +// // No matter what, fclose(fp) will happen. +// DataObject d; +// while (ReadDataObject(fp, &d)) { +// if (d.IsBad()) { +// LOG(ERROR) << "Bad Data"; +// return; +// } +// PushGoodData(d); +// } +// } +// +// You can use Cleanup directly, instead of using MakeCleanup and auto, +// but there's rarely a reason to do that. +// +// You can call 'release()' on a Cleanup object to cancel the cleanup. + +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_ + +#include +#include + +#include "tensorflow/core/platform/macros.h" + +namespace tf_gcs_filesystem { + +// A move-only RAII object that calls a stored cleanup functor when +// destroyed. Cleanup is the return type of gtl::MakeCleanup(F). +template +class Cleanup { + public: + Cleanup() : released_(true), f_() {} + + template + explicit Cleanup(G&& f) // NOLINT + : f_(std::forward(f)) {} // NOLINT(build/c++11) + + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) {} + + // Implicitly move-constructible from any compatible Cleanup. + // The source will be released as if src.release() were called. + // A moved-from Cleanup can be safely destroyed or reassigned. + template + Cleanup(Cleanup&& src) // NOLINT + : released_(src.is_released()), f_(src.release()) {} + + // Assignment to a Cleanup object behaves like destroying it + // and making a new one in its place, analogous to unique_ptr + // semantics. + Cleanup& operator=(Cleanup&& src) { // NOLINT + if (!released_) f_(); + released_ = src.released_; + f_ = src.release(); + return *this; + } + + ~Cleanup() { + if (!released_) f_(); + } + + // Releases the cleanup function instead of running it. + // Hint: use c.release()() to run early. + F release() { + released_ = true; + return std::move(f_); + } + + bool is_released() const { return released_; } + + private: + static_assert(!std::is_reference::value, "F must not be a reference"); + + bool released_ = false; + F f_; +}; + +template ::type> +Cleanup MakeCleanup(F&& f) { + return Cleanup(std::forward(f)); +} + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_CLEANUP_H_ diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h index aa45e71e9b4..3ba7d8d7993 100644 --- a/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h @@ -1,8 +1,11 @@ /* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with the License. 
You may obtain a copy of the License at + http://www.apache.org/licenses/LICENSE-2.0 + Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc new file mode 100644 index 00000000000..102c7fa175c --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.cc @@ -0,0 +1,317 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ +#include "tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h" + +#include +#include +#include +#include + +#include "absl/synchronization/mutex.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/cleanup.h" + +namespace tf_gcs_filesystem { + +bool RamFileBlockCache::BlockNotStale(const std::shared_ptr& block) { + absl::MutexLock l(&block->mu); + if (block->state != FetchState::FINISHED) { + return true; // No need to check for staleness. + } + if (max_staleness_ == 0) return true; // Not enforcing staleness. + return timer_seconds_() - block->timestamp <= max_staleness_; +} + +std::shared_ptr RamFileBlockCache::Lookup( + const Key& key) { + absl::MutexLock lock(&mu_); + auto entry = block_map_.find(key); + if (entry != block_map_.end()) { + if (BlockNotStale(entry->second)) { + if (cache_stats_ != nullptr) { + cache_stats_->RecordCacheHitBlockSize(entry->second->data.size()); + } + return entry->second; + } else { + // Remove the stale block and continue. + RemoveFile_Locked(key.first); + } + } + + // Insert a new empty block, setting the bookkeeping to sentinel values + // in order to update them as appropriate. + auto new_entry = std::make_shared(); + lru_list_.push_front(key); + lra_list_.push_front(key); + new_entry->lru_iterator = lru_list_.begin(); + new_entry->lra_iterator = lra_list_.begin(); + new_entry->timestamp = timer_seconds_(); + block_map_.emplace(std::make_pair(key, new_entry)); + return new_entry; +} + +// Remove blocks from the cache until we do not exceed our maximum size. +void RamFileBlockCache::Trim() { + while (!lru_list_.empty() && cache_size_ > max_bytes_) { + RemoveBlock(block_map_.find(lru_list_.back())); + } +} + +/// Move the block to the front of the LRU list if it isn't already there. +void RamFileBlockCache::UpdateLRU(const Key& key, + const std::shared_ptr& block, + TF_Status* status) { + absl::MutexLock lock(&mu_); + if (block->timestamp == 0) { + // The block was evicted from another thread. Allow it to remain evicted. 
+ return TF_SetStatus(status, TF_OK, ""); + } + if (block->lru_iterator != lru_list_.begin()) { + lru_list_.erase(block->lru_iterator); + lru_list_.push_front(key); + block->lru_iterator = lru_list_.begin(); + } + + // Check for inconsistent state. If there is a block later in the same file + // in the cache, and our current block is not block size, this likely means + // we have inconsistent state within the cache. Note: it's possible some + // incomplete reads may still go undetected. + if (block->data.size() < block_size_) { + Key fmax = std::make_pair(key.first, std::numeric_limits::max()); + auto fcmp = block_map_.upper_bound(fmax); + if (fcmp != block_map_.begin() && key < (--fcmp)->first) { + return TF_SetStatus(status, TF_INTERNAL, + "Block cache contents are inconsistent."); + } + } + + Trim(); + + return TF_SetStatus(status, TF_OK, ""); +} + +void RamFileBlockCache::MaybeFetch(const Key& key, + const std::shared_ptr& block, + TF_Status* status) { + bool downloaded_block = false; + auto reconcile_state = MakeCleanup([this, &downloaded_block, &key, &block] { + // Perform this action in a cleanup callback to avoid locking mu_ after + // locking block->mu. + if (downloaded_block) { + absl::MutexLock l(&mu_); + // Do not update state if the block is already to be evicted. + if (block->timestamp != 0) { + // Use capacity() instead of size() to account for all memory + // used by the cache. + cache_size_ += block->data.capacity(); + // Put to beginning of LRA list. + lra_list_.erase(block->lra_iterator); + lra_list_.push_front(key); + block->lra_iterator = lra_list_.begin(); + block->timestamp = timer_seconds_(); + } + } + }); + // Loop until either block content is successfully fetched, or our request + // encounters an error. + absl::MutexLock l(&block->mu); + TF_SetStatus(status, TF_OK, ""); + while (true) { + switch (block->state) { + case FetchState::ERROR: + // TF_FALLTHROUGH_INTENDED + case FetchState::CREATED: + block->state = FetchState::FETCHING; + block->mu.Unlock(); // Release the lock while making the API call. + block->data.clear(); + block->data.resize(block_size_, 0); + size_t bytes_transferred; + block_fetcher_(key.first, key.second, block_size_, block->data.data(), + &bytes_transferred, status); + if (cache_stats_ != nullptr) { + cache_stats_->RecordCacheMissBlockSize(bytes_transferred); + } + block->mu.Lock(); // Reacquire the lock immediately afterwards + if (TF_GetCode(status) == TF_OK) { + block->data.resize(bytes_transferred, 0); + // Shrink the data capacity to the actual size used. + // NOLINTNEXTLINE: shrink_to_fit() may not shrink the capacity. + std::vector(block->data).swap(block->data); + downloaded_block = true; + block->state = FetchState::FINISHED; + } else { + block->state = FetchState::ERROR; + } + block->cond_var.SignalAll(); + return; + case FetchState::FETCHING: + block->cond_var.WaitWithTimeout(&block->mu, absl::Minutes(1)); + if (block->state == FetchState::FINISHED) { + return TF_SetStatus(status, TF_OK, ""); + } + // Re-loop in case of errors. 
+ break; + case FetchState::FINISHED: + return TF_SetStatus(status, TF_OK, ""); + } + } + return TF_SetStatus( + status, TF_INTERNAL, + "Control flow should never reach the end of RamFileBlockCache::Fetch."); +} + +void RamFileBlockCache::Read(const std::string& filename, size_t offset, + size_t n, char* buffer, size_t* bytes_transferred, + TF_Status* status) { + *bytes_transferred = 0; + if (n == 0) { + return TF_SetStatus(status, TF_OK, ""); + } + if (!IsCacheEnabled() || (n > max_bytes_)) { + // The cache is effectively disabled, so we pass the read through to the + // fetcher without breaking it up into blocks. + return block_fetcher_(filename, offset, n, buffer, bytes_transferred, + status); + } + // Calculate the block-aligned start and end of the read. + size_t start = block_size_ * (offset / block_size_); + size_t finish = block_size_ * ((offset + n) / block_size_); + if (finish < offset + n) { + finish += block_size_; + } + size_t total_bytes_transferred = 0; + // Now iterate through the blocks, reading them one at a time. + for (size_t pos = start; pos < finish; pos += block_size_) { + Key key = std::make_pair(filename, pos); + // Look up the block, fetching and inserting it if necessary, and update the + // LRU iterator for the key and block. + std::shared_ptr block = Lookup(key); + if (!block) { + std::cerr << "No block for key " << key.first << "@" << key.second; + abort(); + } + MaybeFetch(key, block, status); + if (TF_GetCode(status) != TF_OK) return; + UpdateLRU(key, block, status); + if (TF_GetCode(status) != TF_OK) return; + // Copy the relevant portion of the block into the result buffer. + const auto& data = block->data; + if (offset >= pos + data.size()) { + // The requested offset is at or beyond the end of the file. This can + // happen if `offset` is not block-aligned, and the read returns the last + // block in the file, which does not extend all the way out to `offset`. + *bytes_transferred = total_bytes_transferred; + std::stringstream os; + os << "EOF at offset " << offset << " in file " << filename + << " at position " << pos << " with data size " << data.size(); + return TF_SetStatus(status, TF_OUT_OF_RANGE, std::move(os).str().c_str()); + } + auto begin = data.begin(); + if (offset > pos) { + // The block begins before the slice we're reading. + begin += offset - pos; + } + auto end = data.end(); + if (pos + data.size() > offset + n) { + // The block extends past the end of the slice we're reading. + end -= (pos + data.size()) - (offset + n); + } + if (begin < end) { + size_t bytes_to_copy = end - begin; + memcpy(&buffer[total_bytes_transferred], &*begin, bytes_to_copy); + total_bytes_transferred += bytes_to_copy; + } + if (data.size() < block_size_) { + // The block was a partial block and thus signals EOF at its upper bound. + break; + } + } + *bytes_transferred = total_bytes_transferred; + return TF_SetStatus(status, TF_OK, ""); +} + +bool RamFileBlockCache::ValidateAndUpdateFileSignature( + const std::string& filename, int64_t file_signature) { + absl::MutexLock lock(&mu_); + auto it = file_signature_map_.find(filename); + if (it != file_signature_map_.end()) { + if (it->second == file_signature) { + return true; + } + // Remove the file from cache if the signatures don't match. 
+ RemoveFile_Locked(filename); + it->second = file_signature; + return false; + } + file_signature_map_[filename] = file_signature; + return true; +} + +size_t RamFileBlockCache::CacheSize() const { + absl::MutexLock lock(&mu_); + return cache_size_; +} + +void RamFileBlockCache::Prune() { + while (!stop_pruning_thread_.WaitForNotificationWithTimeout( + absl::Microseconds(1000000))) { + absl::MutexLock lock(&mu_); + uint64_t now = timer_seconds_(); + while (!lra_list_.empty()) { + auto it = block_map_.find(lra_list_.back()); + if (now - it->second->timestamp <= max_staleness_) { + // The oldest block is not yet expired. Come back later. + break; + } + // We need to make a copy of the filename here, since it could otherwise + // be used within RemoveFile_Locked after `it` is deleted. + RemoveFile_Locked(std::string(it->first.first)); + } + } +} + +void RamFileBlockCache::Flush() { + absl::MutexLock lock(&mu_); + block_map_.clear(); + lru_list_.clear(); + lra_list_.clear(); + cache_size_ = 0; +} + +void RamFileBlockCache::RemoveFile(const std::string& filename) { + absl::MutexLock lock(&mu_); + RemoveFile_Locked(filename); +} + +void RamFileBlockCache::RemoveFile_Locked(const std::string& filename) { + Key begin = std::make_pair(filename, 0); + auto it = block_map_.lower_bound(begin); + while (it != block_map_.end() && it->first.first == filename) { + auto next = std::next(it); + RemoveBlock(it); + it = next; + } +} + +void RamFileBlockCache::RemoveBlock(BlockMap::iterator entry) { + // This signals that the block is removed, and should not be inadvertently + // reinserted into the cache in UpdateLRU. + entry->second->timestamp = 0; + lru_list_.erase(entry->second->lru_iterator); + lra_list_.erase(entry->second->lra_iterator); + cache_size_ -= entry->second->data.capacity(); + block_map_.erase(entry); +} + +} // namespace tf_gcs_filesystem diff --git a/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h new file mode 100644 index 00000000000..a33ba9d3bdc --- /dev/null +++ b/tensorflow/c/experimental/filesystem/plugins/gcs/ram_file_block_cache.h @@ -0,0 +1,266 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ +#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ + +#include +#include +#include +#include +#include +#include +#include + +#include "absl/base/thread_annotations.h" +#include "absl/synchronization/mutex.h" +#include "absl/synchronization/notification.h" +#include "tensorflow/c/env.h" +#include "tensorflow/c/experimental/filesystem/plugins/gcs/file_block_cache.h" +#include "tensorflow/c/tf_status.h" + +namespace tf_gcs_filesystem { + +/// \brief An LRU block cache of file contents, keyed by {filename, offset}. 
+/// +/// This class should be shared by read-only random access files on a remote +/// filesystem (e.g. GCS). +class RamFileBlockCache : public FileBlockCache { + public: + /// The callback executed when a block is not found in the cache, and needs to + /// be fetched from the backing filesystem. This callback is provided when the + /// cache is constructed. The `status` should be `TF_OK` as long as the + /// read from the remote filesystem succeeded (similar to the semantics of the + /// read(2) system call). + typedef std::function + BlockFetcher; + + RamFileBlockCache(size_t block_size, size_t max_bytes, uint64_t max_staleness, + BlockFetcher block_fetcher, + std::function timer_seconds) + : block_size_(block_size), + max_bytes_(max_bytes), + max_staleness_(max_staleness), + block_fetcher_(block_fetcher), + timer_seconds_(timer_seconds), + pruning_thread_(nullptr, [](TF_Thread* thread) { TF_JoinThread(thread); }) { + if (max_staleness_ > 0) { + TF_ThreadOptions thread_options; + TF_DefaultThreadOptions(&thread_options); + pruning_thread_.reset( + TF_StartThread(&thread_options, "TF_prune_FBC", PruneThread, this)); + } + std::cout << "GCS file block cache is " + << (IsCacheEnabled() ? "enabled" : "disabled"); + } + + ~RamFileBlockCache() override { + if (pruning_thread_) { + stop_pruning_thread_.Notify(); + // Destroying pruning_thread_ will block until Prune() receives the above + // notification and returns. + pruning_thread_.reset(); + } + } + + /// Read `n` bytes from `filename` starting at `offset` into `buffer`. This + /// method will set `status` to: + /// + /// 1) The error from the remote filesystem, if the read from the remote + /// filesystem failed. + /// 2) `TF_FAILED_PRECONDITION` if the read from the remote filesystem + /// succeeded, + /// but the read returned a partial block, and the LRU cache contained a + /// block at a higher offset (indicating that the partial block should have + /// been a full block). + /// 3) `TF_OUT_OF_RANGE` if the read from the remote filesystem succeeded, but + /// the file contents do not extend past `offset` and thus nothing was + /// placed in `out`. + /// 4) `TF_OK` otherwise (i.e. the read succeeded, and at least one byte was + /// placed + /// in `buffer`). + /// + /// Caller is responsible for allocating memory for `buffer`. + /// `buffer` will be left unchanged in case of errors. + void Read(const std::string& filename, size_t offset, size_t n, char* buffer, + size_t* bytes_transferred, TF_Status* status) override; + + // Validate the given file signature with the existing file signature in the + // cache. Returns true if the signature doesn't change or the file doesn't + // exist before. If the signature changes, update the existing signature with + // the new one and remove the file from cache. + bool ValidateAndUpdateFileSignature(const std::string& filename, + int64_t file_signature) override + ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all cached blocks for `filename`. + void RemoveFile(const std::string& filename) override + ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all cached data. + void Flush() override ABSL_LOCKS_EXCLUDED(mu_); + + /// Accessors for cache parameters. + size_t block_size() const override { return block_size_; } + size_t max_bytes() const override { return max_bytes_; } + uint64_t max_staleness() const override { return max_staleness_; } + + /// The current size (in bytes) of the cache. + size_t CacheSize() const override ABSL_LOCKS_EXCLUDED(mu_); + + // Returns true if the cache is enabled. 
If false, the BlockFetcher callback + // is always executed during Read. + bool IsCacheEnabled() const override { + return block_size_ > 0 && max_bytes_ > 0; + } + + // We can not pass a lambda with capture as a function pointer to + // `TF_StartThread`, so we have to wrap `Prune` inside a static function. + static void PruneThread(void* param) { + auto ram_file_block_cache = static_cast(param); + ram_file_block_cache->Prune(); + } + + private: + /// The size of the blocks stored in the LRU cache, as well as the size of the + /// reads from the underlying filesystem. + const size_t block_size_; + /// The maximum number of bytes (sum of block sizes) allowed in the LRU cache. + const size_t max_bytes_; + /// The maximum staleness of any block in the LRU cache, in seconds. + const uint64_t max_staleness_; + /// The callback to read a block from the underlying filesystem. + const BlockFetcher block_fetcher_; + /// The callback to read timestamps. + const std::function timer_seconds_; + + /// \brief The key type for the file block cache. + /// + /// The file block cache key is a {filename, offset} pair. + typedef std::pair Key; + + /// \brief The state of a block. + /// + /// A block begins in the CREATED stage. The first thread will attempt to read + /// the block from the filesystem, transitioning the state of the block to + /// FETCHING. After completing, if the read was successful the state should + /// be FINISHED. Otherwise the state should be ERROR. A subsequent read can + /// re-fetch the block if the state is ERROR. + enum class FetchState { + CREATED, + FETCHING, + FINISHED, + ERROR, + }; + + /// \brief A block of a file. + /// + /// A file block consists of the block data, the block's current position in + /// the LRU cache, the timestamp (seconds since epoch) at which the block + /// was cached, a coordination lock, and state & condition variables. + /// + /// Thread safety: + /// The iterator and timestamp fields should only be accessed while holding + /// the block-cache-wide mu_ instance variable. The state variable should only + /// be accessed while holding the Block's mu lock. The data vector should only + /// be accessed after state == FINISHED, and it should never be modified. + /// + /// In order to prevent deadlocks, never grab the block-cache-wide mu_ lock + /// AFTER grabbing any block's mu lock. It is safe to grab mu without locking + /// mu_. + struct Block { + /// The block data. + std::vector data; + /// A list iterator pointing to the block's position in the LRU list. + std::list::iterator lru_iterator; + /// A list iterator pointing to the block's position in the LRA list. + std::list::iterator lra_iterator; + /// The timestamp (seconds since epoch) at which the block was cached. + uint64_t timestamp; + /// Mutex to guard state variable + absl::Mutex mu; + /// The state of the block. + FetchState state ABSL_GUARDED_BY(mu) = FetchState::CREATED; + /// Wait on cond_var if state is FETCHING. + absl::CondVar cond_var; + }; + + /// \brief The block map type for the file block cache. + /// + /// The block map is an ordered map from Key to Block. + typedef std::map> BlockMap; + + /// Prune the cache by removing files with expired blocks. + void Prune() ABSL_LOCKS_EXCLUDED(mu_); + + bool BlockNotStale(const std::shared_ptr& block) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Look up a Key in the block cache. 
+ std::shared_ptr Lookup(const Key& key) ABSL_LOCKS_EXCLUDED(mu_); + + void MaybeFetch(const Key& key, const std::shared_ptr& block, + TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_); + + /// Trim the block cache to make room for another entry. + void Trim() ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Update the LRU iterator for the block at `key`. + void UpdateLRU(const Key& key, const std::shared_ptr& block, + TF_Status* status) ABSL_LOCKS_EXCLUDED(mu_); + + /// Remove all blocks of a file, with mu_ already held. + void RemoveFile_Locked(const std::string& filename) + ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// Remove the block `entry` from the block map and LRU list, and update the + /// cache size accordingly. + void RemoveBlock(BlockMap::iterator entry) ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_); + + /// The cache pruning thread that removes files with expired blocks. + std::unique_ptr> pruning_thread_; + + /// Notification for stopping the cache pruning thread. + absl::Notification stop_pruning_thread_; + + /// Guards access to the block map, LRU list, and cached byte count. + mutable absl::Mutex mu_; + + /// The block map (map from Key to Block). + BlockMap block_map_ ABSL_GUARDED_BY(mu_); + + /// The LRU list of block keys. The front of the list identifies the most + /// recently accessed block. + std::list lru_list_ ABSL_GUARDED_BY(mu_); + + /// The LRA (least recently added) list of block keys. The front of the list + /// identifies the most recently added block. + /// + /// Note: blocks are added to lra_list_ only after they have successfully been + /// fetched from the underlying block store. + std::list lra_list_ ABSL_GUARDED_BY(mu_); + + /// The combined number of bytes in all of the cached blocks. + size_t cache_size_ ABSL_GUARDED_BY(mu_) = 0; + + // A filename->file_signature map. + std::map file_signature_map_ ABSL_GUARDED_BY(mu_); +}; + +} // namespace tf_gcs_filesystem + +#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_RAM_FILE_BLOCK_CACHE_H_ From 1f2de130a0ba3ffb0985028cd35ed9ffd8e464a2 Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Tue, 7 Jul 2020 14:34:52 -0400 Subject: [PATCH 17/88] Update code and BUILD file. --- tensorflow/python/keras/benchmarks/BUILD | 1 - .../python/keras/benchmarks/keras_cpu_benchmark_test.py | 5 ++--- 2 files changed, 2 insertions(+), 4 deletions(-) diff --git a/tensorflow/python/keras/benchmarks/BUILD b/tensorflow/python/keras/benchmarks/BUILD index 682d1b24be8..18d72398344 100644 --- a/tensorflow/python/keras/benchmarks/BUILD +++ b/tensorflow/python/keras/benchmarks/BUILD @@ -31,7 +31,6 @@ py_test( deps = [ ":benchmark_util", "//tensorflow:tensorflow_py", - "//tensorflow/python/keras", "//third_party/py/numpy", ], ) diff --git a/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py b/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py index 63eedfcd11c..9d9d4c3edd7 100644 --- a/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py +++ b/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py @@ -23,7 +23,6 @@ import six import tensorflow as tf from tensorflow.python.platform import benchmark -from tensorflow.python.platform import test from tensorflow.python.keras.benchmarks import benchmark_util # Loss function and optimizer. 
@@ -32,7 +31,7 @@ _OPTIMIZER = 'rmsprop' class KerasModelCPUBenchmark( - six.with_metaclass(benchmark.ParameterizedBenchmark, test.Benchmark)): + six.with_metaclass(benchmark.ParameterizedBenchmark, tf.test.Benchmark)): """Required Arguments for measure_performance: x: Input data, it could be Numpy or load from tfds. y: Target data. If `x` is a dataset, generator instance, @@ -137,4 +136,4 @@ class KerasModelCPUBenchmark( if __name__ == '__main__': - test.main() \ No newline at end of file + tf.test.main() \ No newline at end of file From 41a7a73cda7da9e95ca9324bbc61bef54e0f20db Mon Sep 17 00:00:00 2001 From: Xingyu Long Date: Tue, 7 Jul 2020 15:30:26 -0400 Subject: [PATCH 18/88] Fix minor error usage in LSTM part. --- tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py b/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py index 9d9d4c3edd7..24ca07cdb71 100644 --- a/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py +++ b/tensorflow/python/keras/benchmarks/keras_cpu_benchmark_test.py @@ -123,7 +123,7 @@ class KerasModelCPUBenchmark( """Benchmark for LSTM model on synthetic imdb review dataset.""" lstm_x = np.random.randint(0, 1999, size=(2500, 100)) lstm_y = np.random.random((2500, 1)) - results = benchmark_util.measure_performance(self._imdb_lstm(), + results = benchmark_util.measure_performance(self._imdb_lstm, x=lstm_x, y=lstm_y, batch_size=batch_size, From 2d67a4c8a5604849b0853d775b5cf93d5a3c30eb Mon Sep 17 00:00:00 2001 From: Gaurav Singh Date: Thu, 4 Jun 2020 12:19:57 -0400 Subject: [PATCH 19/88] [compiler] Fix segmentation faulit in segment graph Signed-off-by: Gaurav Singh Change the log message Signed-off-by: Gaurav Singh --- tensorflow/compiler/tf2tensorrt/segment/segment.cc | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index d9080b6f69a..5bcff4bb9e1 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -721,6 +721,10 @@ Status SegmentGraph(const Graph* tf_graph, std::vector> node_segments; for (int i = 0; i < graph->num_node_ids(); ++i) { SimpleNode* node = graph->FindNodeId(i); + if (!node) { + VLOG(1) << "Node " << i << " doesn't exist in the graph"; + continue; + } auto exclude_node = [&](absl::string_view reason) { VLOG(1) << "Not a TF-TRT candidate, " << "(Op type: " << node->tf_node()->type_string() << "), " From 82e12bf3876a68a5c1cafe8bb622fb61e393e149 Mon Sep 17 00:00:00 2001 From: Tim Shen Date: Tue, 7 Jul 2020 20:55:32 -0700 Subject: [PATCH 20/88] [XLA/GPU] Decouple hlo_ordering from thunk_schedule. The plumbing before this goes like this: * hlo_ordering -> buffer_assignment * buffer_assignment -> ir_emitter_unnested (DFS order) -> thunks * Apply hlo_ordering to thunks -> thunk_schedule After: * hlo_ordering -> buffer_assignment * buffer_assignment -> ir_emitter_unnested (hlo_ordering) -> thunks * thunks -> thunk_schedule (order unchanged) The idea is that since thunks are scheduled to the a certain total order anyway, just use that order to invoke the emitter. It saves an extra schedule, but most importantly, it removes uses of Thunk::hlo_instruction(), which makes MLIR/GPU transition easier. 
PiperOrigin-RevId: 320117281 Change-Id: I0ee9ff14e71869ea09d6223ae10448317298096f --- tensorflow/compiler/xla/service/gpu/gpu_compiler.cc | 8 +++++--- .../compiler/xla/service/gpu/thunk_schedule.cc | 13 +++++-------- .../compiler/xla/service/gpu/thunk_schedule.h | 3 +-- .../xla/service/mlir_gpu/lhlo_dialect_emitter.cc | 6 ++++-- .../xla/service/mlir_gpu/lhlo_dialect_emitter.h | 3 ++- .../xla/service/mlir_gpu/mlir_compiler_impl.cc | 6 +++--- 6 files changed, 20 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc index 2b31099d26f..f5ed7e3a114 100644 --- a/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc +++ b/tensorflow/compiler/xla/service/gpu/gpu_compiler.cc @@ -518,8 +518,11 @@ static Status CompileModuleToLlvmIrImpl( { XLA_SCOPED_LOGGING_TIMER("GpuCompiler::RunBackend - IR emission"); - TF_RETURN_IF_ERROR(entry_computation->Accept(&ir_emitter)); + TF_RETURN_IF_ERROR(entry_computation->AcceptOrdered( + &ir_emitter, (*hlo_schedule)->ThunkLaunchOrder())); } + // The order of `thunk_sequence` corresponds to + // `hlo_schedule->ThunkLaunchOrder()`. *thunk_sequence = ir_emitter.ConsumeThunkSequence(); return Status::OK(); } @@ -610,8 +613,7 @@ StatusOr> GpuCompiler::RunBackend( gpu_version, stream_exec)); auto thunk_schedule = absl::make_unique( - std::move(thunk_sequence), std::move(stream_assignment), - hlo_schedule->ThunkLaunchOrder()); + std::move(thunk_sequence), std::move(stream_assignment)); if (DumpingEnabledForHloModule(*module)) { DumpToFileInDirOrStdout(*module, "", "thunk_schedule", thunk_schedule->ToString()); diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc index daa5f33e560..a91466e5c5f 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.cc @@ -49,21 +49,18 @@ void ThunkSchedule::AddDependenciesOnTransitiveOperands( ThunkSchedule::ThunkSchedule( std::unique_ptr thunks, - std::unique_ptr stream_assignment, - const std::vector& hlo_total_order) + std::unique_ptr stream_assignment) : thunks_(std::move(thunks)), stream_assignment_(std::move(stream_assignment)) { + for (auto& thunk : *thunks_) { + thunk_total_order_.push_back(thunk.get()); + } + absl::flat_hash_map hlo_to_thunk; for (const auto& thunk : *thunks_) { InsertOrDie(&hlo_to_thunk, thunk->hlo_instruction(), thunk.get()); } - for (HloInstruction* hlo : hlo_total_order) { - if (Thunk** thunk = tensorflow::gtl::FindOrNull(hlo_to_thunk, hlo)) { - thunk_total_order_.push_back(*thunk); - } - } - for (const Thunk* thunk : thunk_total_order_) { const auto* dst = thunk->hlo_instruction(); CHECK(stream_assignment_->HasStreamAssigned(*dst)); diff --git a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h index 549378debd5..73da708aa3d 100644 --- a/tensorflow/compiler/xla/service/gpu/thunk_schedule.h +++ b/tensorflow/compiler/xla/service/gpu/thunk_schedule.h @@ -47,8 +47,7 @@ namespace gpu { class ThunkSchedule { public: ThunkSchedule(std::unique_ptr thunks, - std::unique_ptr stream_assignment, - const std::vector& hlo_total_order); + std::unique_ptr stream_assignment); // Returns the total order of executing all the thunks. 
const std::vector& TotalOrder() const { return thunk_total_order_; } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index a65096b7eac..3654271da53 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -226,8 +226,10 @@ absl::string_view LhloDialectEmitter::platform_name() const { return platform_->Name(); } -Status LhloDialectEmitter::EmitComputation(const HloComputation& computation) { - return computation.root_instruction()->Accept(this); +Status LhloDialectEmitter::EmitComputation( + const HloComputation& computation, + absl::Span ordering) { + return computation.AcceptOrdered(this, ordering); } StatusOr LhloDialectEmitter::CreateFunction( diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h index 185c1e13bb7..2fe1947d625 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.h @@ -47,7 +47,8 @@ class LhloDialectEmitter : public DfsHloVisitorWithDefault, ::mlir::ModuleOp mlir_module); ~LhloDialectEmitter() override = default; - Status EmitComputation(const HloComputation& computation); + Status EmitComputation(const HloComputation& computation, + absl::Span ordering); // The following methods implement the DfsHloVisitor interface. // diff --git a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc index beabc99a173..a2bee43a0f8 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/mlir_compiler_impl.cc @@ -489,7 +489,8 @@ StatusOr> MlirCompilerImpl::RunBackend( stream_exec->platform(), *mlir_module); TF_RETURN_IF_ERROR(lhlo_emitter.EmitComputation( - *emission_context.getHloModule()->entry_computation())); + *emission_context.getHloModule()->entry_computation(), + hlo_schedule->ThunkLaunchOrder())); TF_RETURN_IF_ERROR( module_hook_.invoke(IRHook::LoweringStage::LHLO, *mlir_module)); @@ -539,8 +540,7 @@ StatusOr> MlirCompilerImpl::RunBackend( gpu::PtxOptsFromConfig(config))); auto thunk_schedule = absl::make_unique( - std::move(thunk_sequence), std::move(stream_assignment), - hlo_schedule->ThunkLaunchOrder()); + std::move(thunk_sequence), std::move(stream_assignment)); if (DumpingEnabledForHloModule(*emission_context.getHloModule())) { DumpToFileInDirOrStdout(*emission_context.getHloModule(), "", From 3dda4182aa068ed37304ad77014faca5c0935386 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Tue, 7 Jul 2020 20:59:03 -0700 Subject: [PATCH 21/88] [XLA:CPU] Teach dot_op_emitter how to tile&vectorize linalg matmuls And turn them on by default. This is on-par with the existing emitter, sometimes better and unlocks more potential. The strategy classes are duplicated right now, but I expect them to graduate to mlir core soon. I'm planning to remove the custom LLVM IR emitters if this turns out to be stable enough. 
PiperOrigin-RevId: 320117625 Change-Id: I3580df9990ca2a022a49327fa819c2086fd1e2ed --- tensorflow/compiler/xla/service/cpu/BUILD | 26 +- .../compiler/xla/service/cpu/cpu_options.cc | 7 + .../xla/service/cpu/dot_op_emitter.cc | 41 +-- .../compiler/xla/service/cpu/mlir_emitter.cc | 8 +- .../cpu/mlir_matmul_codegen_strategy.cc | 269 ------------------ .../cpu/mlir_matmul_codegen_strategy.h | 188 ------------ .../xla/service/cpu/target_machine_features.h | 12 - .../cpu/target_machine_features_fake.h | 4 - 8 files changed, 14 insertions(+), 541 deletions(-) delete mode 100644 tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc delete mode 100644 tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index b9e10bfb083..102753b882f 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -471,7 +471,6 @@ cc_library( ":cpu_runtime", ":ir_emission_utils", ":mlir_emitter", - ":mlir_matmul_codegen_strategy", ":target_machine_features", ":tiled_dot_emitter", ":vector_support_library", @@ -1103,33 +1102,12 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:IPO", "@llvm-project//llvm:Linker", - "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", + "@llvm-project//mlir:LLVMTransforms", + "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TargetLLVMIR", - "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorToLLVM", ], ) - -cc_library( - name = "mlir_matmul_codegen_strategy", - srcs = ["mlir_matmul_codegen_strategy.cc"], - hdrs = ["mlir_matmul_codegen_strategy.h"], - deps = [ - "@llvm-project//llvm:Support", - "@llvm-project//mlir:Affine", - "@llvm-project//mlir:Analysis", - "@llvm-project//mlir:IR", - "@llvm-project//mlir:LinalgOps", - "@llvm-project//mlir:LinalgTransforms", - "@llvm-project//mlir:Pass", - "@llvm-project//mlir:SCFDialect", - "@llvm-project//mlir:StandardOps", - "@llvm-project//mlir:Support", - "@llvm-project//mlir:Transforms", - "@llvm-project//mlir:VectorOps", - "@llvm-project//mlir:VectorToSCF", - ], -) diff --git a/tensorflow/compiler/xla/service/cpu/cpu_options.cc b/tensorflow/compiler/xla/service/cpu/cpu_options.cc index ff654c83d61..c0222010fd9 100644 --- a/tensorflow/compiler/xla/service/cpu/cpu_options.cc +++ b/tensorflow/compiler/xla/service/cpu/cpu_options.cc @@ -25,6 +25,7 @@ const char* const kXlaOptimizeForSizeCpuOption = "xla_cpu_optimize_for_size"; const char* const kLlvmIrDotTilingFactor = "xla_llvm_dot_tiling_factor"; const char* const kXlaForceEnableExperimentalLlvmIrGemm = "xla_force_enable_experimental_llvm_ir_gemm"; +const char* const kXlaUseLinalgForDot = "xla_use_linalg_for_dot"; const char* const kLlvmIrGemmTileSize = "xla_llvm_ir_gemm_tile_size"; } // namespace @@ -63,6 +64,12 @@ bool ForceEnableExperimentalLlvmIrGemm(const HloModuleConfig& config) { return extra_options_map.count(kXlaForceEnableExperimentalLlvmIrGemm) > 0; } +bool UseLinalgForDot(const HloModuleConfig& config) { + const auto& extra_options_map = + config.debug_options().xla_backend_extra_options(); + return extra_options_map.count(kXlaUseLinalgForDot) > 0; +} + static absl::string_view RemoveSuffix(absl::string_view str, absl::string_view suffix) { CHECK_GE(str.size(), suffix.size()); diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 
574d83c68c8..72f3a4dfac7 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -31,12 +31,10 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project -#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" -#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -204,20 +202,6 @@ class DotOpEmitter { .value_or(kDefaultTileSize); } - std::array GetMlirGemmTileSize() const { - // Tile by 4 x registers x register size. This was picked by running - // small matmuls on Haswell and Skylake. There's a lot of room for - // improvement here. - constexpr int64_t kDefaultTileSizeForM = 4; - int64_t elements_per_register = - target_machine_features_.vector_register_num_elements( - *b_->GetInsertBlock()->getParent(), - dot_info_.result_shape.element_type()); - int64_t num_registers = target_machine_features_.vector_register_count( - *b_->GetInsertBlock()->getParent()); - return {{kDefaultTileSizeForM, num_registers, elements_per_register}}; - } - DotInfo dot_info_; string dot_hlo_name_; const llvm_ir::IrArray& target_array_; @@ -266,7 +250,6 @@ Status DotOpEmitter::EmitLinalgMatmul() { absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_", dot_info_.lhs_shape.ToString(true), "_", dot_info_.rhs_shape.ToString(true)); - return EmitMlirFuncAndCall( mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr, operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) { @@ -276,27 +259,6 @@ Status DotOpEmitter::EmitLinalgMatmul() { mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{}, mlir::ValueRange{b, c, a}); mlir::edsc::intrinsics::std_ret(); - - mlir::linalg::LinalgTilingOptions tilingOptions; - tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize()); - int64 alignment = - target_machine_features_.minimum_alignment_for_allocation( - ShapeUtil::ByteSizeOf(dot_info_.result_shape)); - mlir_strategy::MatmulCodegenStrategy strategy; - strategy.tile(tilingOptions) - .promote( - mlir::linalg::LinalgPromotionOptions() - .setAlignment(alignment) - .setUseFullTileBuffersByDefault(true) - .setUseAlloca(true)) - .vectorize() - .setVectorTransformsOptions( - mlir::vector::VectorTransformsOptions() - .setVectorTransformsOptions( - mlir::vector::VectorContractLowering::OuterProduct)) - .setVectorTransferToSCFOptions( - mlir::VectorTransferToSCFOptions().setUnroll(true)); - strategy.transform(function); }); } @@ -986,8 +948,7 @@ DotImplementationStrategy GetDotImplementationStrategy( if (IsAlignedGemm(dot_info, target_machine_features)) { if (CanEmitTiledLlvmIrGemm(config, dot_info, target_machine_features)) { - return primitive_util::IsFloatingPointType( - dot_info.result_shape.element_type()) + return options::UseLinalgForDot(config) ? 
DotImplementationStrategy::kLinalgMatmul : DotImplementationStrategy::kTiledLlvmIrGemm; } diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc index d17f4671327..e7d52c288d5 100644 --- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -17,14 +17,14 @@ limitations under the License. #include "llvm/Linker/Linker.h" #include "llvm/Transforms/IPO/Internalize.h" -#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project +#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project +#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Target/LLVMIR.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_utils.h" namespace xla { @@ -35,9 +35,9 @@ namespace { std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { mlir::PassManager manager(module->getContext()); manager.addPass(mlir::createConvertLinalgToLoopsPass()); - manager.addPass(mlir::createLowerAffinePass()); - manager.addPass(mlir::createLowerToCFGPass()); + manager.addPass(mlir::createConvertLinalgToLLVMPass()); manager.addPass(mlir::createConvertVectorToLLVMPass()); + manager.addPass(mlir::createLowerToLLVMPass()); CHECK(succeeded(manager.run(*module))); return mlir::translateModuleToLLVMIR(*module); } diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc deleted file mode 100644 index ea89071a967..00000000000 --- a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc +++ /dev/null @@ -1,269 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" - -#include "llvm/ADT/STLExtras.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/Support/Debug.h" -#include "mlir/Analysis/SliceAnalysis.h" // from @llvm-project -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project -#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project -#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project -#include "mlir/Dialect/SCF/Utils.h" // from @llvm-project -#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project -#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project -#include "mlir/IR/AffineExpr.h" // from @llvm-project -#include "mlir/IR/Attributes.h" // from @llvm-project -#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project -#include "mlir/IR/Dominance.h" // from @llvm-project -#include "mlir/IR/MLIRContext.h" // from @llvm-project -#include "mlir/IR/OperationSupport.h" // from @llvm-project -#include "mlir/IR/PatternMatch.h" // from @llvm-project -#include "mlir/IR/StandardTypes.h" // from @llvm-project -#include "mlir/IR/Value.h" // from @llvm-project -#include "mlir/IR/Visitors.h" // from @llvm-project -#include "mlir/Pass/Pass.h" // from @llvm-project -#include "mlir/Pass/PassManager.h" // from @llvm-project -#include "mlir/Support/LogicalResult.h" // from @llvm-project -#include "mlir/Transforms/LoopUtils.h" // from @llvm-project -#include "mlir/Transforms/Passes.h" // from @llvm-project - -// TODO(kramerb): Remove this once strategy is in mlir core. - -using namespace mlir; // NOLINT -using namespace mlir::linalg; // NOLINT - -#define DEBUG_TYPE "matmul-codegen-strategy" - -namespace xla { -namespace cpu { -namespace mlir_strategy { - -//===----------------------------------------------------------------------===// -// TODO: Cleanup and upstream these to go into core. Please ignore for now ! -//===----------------------------------------------------------------------===// -static void hoistRedundantCopies(FuncOp func) { - bool changed = true; - while (changed) { - changed = false; - func.walk([&](linalg::FillOp op) { - auto loop = op.getParentOfType(); - if (!loop) return; - - for (auto operand : op.getOperands()) - if (!loop.isDefinedOutsideOfLoop(operand)) return; - - // Hoist fill before. - op.getOperation()->moveBefore(loop); - changed = true; - }); - - func.walk([&](linalg::CopyOp op) { - auto loop = op.getParentOfType(); - if (!loop) return; - - for (auto operand : op.getOperands()) - if (!loop.isDefinedOutsideOfLoop(operand)) return; - - Value sourceView = op.getInput(0); - while (auto subViewOp = sourceView.getDefiningOp()) - sourceView = subViewOp.getViewSource(); - - // Source traces back to a block argument. 
- if (sourceView.isa()) { - op.getOperation()->moveBefore(loop); - } else { - assert(sourceView.getDefiningOp() || - sourceView.getDefiningOp() || - sourceView.getDefiningOp()); - op.getOperation()->moveAfter(loop); - } - changed = true; - }); - } -} - -/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing: -/// `%lb + %step * new_dim` where -/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an -/// AffineDimExpr depending on whether the value is constant or not. -/// 2. the AffineExpr for %step is either an AffineConstantExpr or an -/// AffineSymbolExpr depending on whether the value is constant or not. -/// -static void substitute(scf::ForOp forOp, SmallVectorImpl &exprs, - SmallVectorImpl &dims, - SmallVectorImpl &symbols) { - MLIRContext *ctx = forOp.getContext(); - auto lbConstant = forOp.lowerBound().getDefiningOp(); - AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx) - : getAffineDimExpr(dims.size(), ctx); - - auto stepConstant = forOp.step().getDefiningOp(); - AffineExpr step = stepConstant - ? getAffineConstantExpr(stepConstant.getValue(), ctx) - : getAffineSymbolExpr(symbols.size(), ctx); - - if (!lbConstant) dims.push_back(forOp.lowerBound()); - if (!stepConstant) symbols.push_back(forOp.step()); - exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx)); - - auto ubConstant = forOp.upperBound().getDefiningOp(); - AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx) - : getAffineDimExpr(dims.size(), ctx); - if (!ubConstant) dims.push_back(forOp.upperBound()); - exprs.push_back(ub); - - dims.push_back(forOp.getInductionVar()); -} - -/// Traverse the . -static void substitute(AffineMinOp minOp, SmallVectorImpl &exprs, - SmallVectorImpl &dims, - SmallVectorImpl &symbols) { - MLIRContext *ctx = minOp.getContext(); - for (Value v : minOp.getDimOperands()) { - if (auto forOp = scf::getForInductionVarOwner(v)) { - substitute(forOp, exprs, dims, symbols); - continue; - } - if (auto parentMinOp = v.getDefiningOp()) { - substitute(parentMinOp, exprs, dims, symbols); - continue; - } - exprs.push_back(getAffineDimExpr(dims.size(), ctx)); - dims.push_back(v); - } -} - -/// Perform folding of chains of AffineMinOp. 
-struct AffineMinCanonicalizationPattern : public OpRewritePattern { - using OpRewritePattern::OpRewritePattern; - - LogicalResult matchAndRewrite(AffineMinOp minOp, - PatternRewriter &rewriter) const override; -}; - -LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite( - AffineMinOp minOp, PatternRewriter &rewriter) const { - LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: " - << *minOp.getOperation() << "\n"); - - int64_t min = std::numeric_limits::max(); - for (auto e : minOp.map().getResults()) - if (auto cstExpr = e.dyn_cast()) - min = std::min(min, cstExpr.getValue()); - if (min == std::numeric_limits::max()) return failure(); - - SmallVector exprs; - SmallVector dims, symbols; - substitute(minOp, exprs, dims, symbols); - - SmallVector operands = dims; - operands.append(symbols.begin(), symbols.end()); - - MLIRContext *ctx = minOp.getContext(); - auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx); - LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n"); - - SmallVector modExprs; - for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) - modExprs.push_back(getAffineDimExpr(idx, ctx) % min); - map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map); - canonicalizeMapAndOperands(&map, &operands); - map = simplifyAffineMap(map); - - LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n"; - llvm::interleaveComma(operands, llvm::dbgs())); - - if (!llvm::all_of(map.getResults(), [](AffineExpr e) { - if (auto cst = e.dyn_cast()) - return cst.getValue() == 0; - return false; - })) - return failure(); - - rewriter.replaceOpWithNewOp(minOp, min); - return success(); -} -//===----------------------------------------------------------------------===// -// END TODO -//===----------------------------------------------------------------------===// - -void MatmulCodegenStrategy::transform(FuncOp func) const { - MLIRContext *context = func.getContext(); - // Emplace patterns one at a time while also maintaining a simple chained - // state transition. - unsigned stepCount = 0; - SmallVector stage1Patterns; - auto zeroState = Identifier::get(std::to_string(stepCount), context); - auto currentState = zeroState; - for (auto &t : transformation_sequence) { - auto nextState = Identifier::get(std::to_string(++stepCount), context); - auto marker = (currentState == zeroState) - ? linalg::LinalgMarker({}, nextState) - : linalg::LinalgMarker(currentState, nextState); - stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker)); - currentState = nextState; - } - - OwningRewritePatternList stage2Patterns = - linalg::getLinalgTilingCanonicalizationPatterns(context); - stage2Patterns.insert(context); - - auto stage3Transforms = [](Operation *op) { - // Some of these may be too aggressive as a stage 3 that is applied on each - // stage 1 application and may have to be split out to post staged patterns - // application (in which case they could just be passes, TBD). 
- PassManager pm(op->getContext()); - pm.addPass(createLoopInvariantCodeMotionPass()); - if (failed(pm.run(op->getParentOfType()))) - llvm_unreachable("Unexpected failure in cleanup pass pipeline."); - promoteSingleIterationLoops(cast(op)); - hoistViewAllocOps(cast(op)); - hoistRedundantVectorTransfers(cast(op)); - hoistRedundantCopies(cast(op)); - return success(); - }; - linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns, - stage3Transforms); - - //===--------------------------------------------------------------------===// - // Post staged patterns transforms - //===--------------------------------------------------------------------===// - // Programmatic controlled lowering of vector.contract only. - OwningRewritePatternList vectorContractLoweringPatterns; - vectorContractLoweringPatterns - .insert( - vector_transforms_options, context); - applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns); - - // Programmatic controlled lowering of vector.transfer only. - OwningRewritePatternList vectorToLoopsPatterns; - populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, - vector_to_scf_options); - applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns); -} - -} // namespace mlir_strategy -} // namespace cpu -} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h deleted file mode 100644 index 3b11b750c47..00000000000 --- a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h +++ /dev/null @@ -1,188 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ -#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ - -#include "llvm/ADT/ArrayRef.h" -#include "llvm/ADT/SmallVector.h" -#include "llvm/ADT/StringSwitch.h" -#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project -#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project -#include "mlir/Support/LLVM.h" // from @llvm-project - -// TODO(kramerb): Remove this once strategy is in mlir core. - -namespace xla { -namespace cpu { -namespace mlir_strategy { - -/// Abstract Transformation class applied in a sequence that also handles state -/// through markers. -struct Transformation { - virtual ~Transformation() = default; - virtual mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0; - mlir::linalg::LinalgMarker marker; -}; - -/// Promotion transformation enqueues a particular stage-1 pattern for -/// `Tile`with the appropriate `options`. -// TODO: variadic LinalgOpTypes. 
-template -struct Tile : public Transformation { - explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {} - - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList tiling_patterns; - tiling_patterns.insert>( - context, options, m); - return tiling_patterns; - } - - private: - mlir::linalg::LinalgTilingOptions options; -}; - -/// Promotion transformation enqueues a particular stage-1 pattern for -/// `Promote`with the appropriate `options`. -// TODO: variadic LinalgOpTypes. -template -struct Promote : public Transformation { - explicit Promote(mlir::linalg::LinalgPromotionOptions options) - : options(options) {} - - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList promotion_patterns; - promotion_patterns - .insert>(context, - options, m); - return promotion_patterns; - } - - private: - mlir::linalg::LinalgPromotionOptions options; -}; - -/// Vectorization transformation enqueues a particular stage-1 pattern for -/// `LinalgVectorizationPattern` as well as copy to vector -/// transfer rewrite forwarding patterns. -// TODO: variadic LinalgOpTypes. -template -struct Vectorize : public Transformation { - mlir::OwningRewritePatternList buildRewritePatterns( - mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { - mlir::OwningRewritePatternList vectorization_patterns; - // FillOp may interfere with forwarding patterns atm, so we bump up the - // priority of LinalgCopyVTRForwardingPattern / - // LinalgCopyVTWForwardingPattern. - vectorization_patterns - .insert>(context, - m); - vectorization_patterns.insert( - context, - /*benefit=*/2); - return vectorization_patterns; - } -}; - -/// Matmul-specific strategy object controls how a linalg.matmul is -/// progressively lowered. -/// The strategy uses a 3-level staged patterns strategy which allows ordering -/// transformations by using the Linalg `applyStagedPatterns` function, where: -/// 1. The first stage consists of the successive `tile`, `promote` and -/// `vectorize` patterns, applied sequentially. -/// 2. The second stage consists of common local canonicalization patterns -/// that are applied eagerly after each stage-1 pattern. -/// 3. the third stage consists of more global transformation, also applied -/// eagerly, after all stage-2 patterns. Such more global transformations -struct MatmulCodegenStrategy { - /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling - /// `options`. - template - MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) { - transformation_sequence.emplace_back(new Tile(options)); - return *this; - } - /// Conditionally append a pattern to add a level of tiling for `LinalgOpType` - /// with tiling `options`. - template - MatmulCodegenStrategy &tileIf(bool b, - mlir::linalg::LinalgTilingOptions options) { - return b ? tile(options) : *this; - } - /// Append a pattern to add a level of promotion for `LinalgOpType` with - /// promotion `options`. - template - MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) { - transformation_sequence.emplace_back(new Promote(options)); - return *this; - } - /// Conditionally append a pattern to add a level of promotion for - /// `LinalgOpType` with promotion `options`. 
- template - MatmulCodegenStrategy &promoteIf( - bool b, mlir::linalg::LinalgPromotionOptions options) { - return b ? promote(options) : *this; - return *this; - } - /// Append a pattern to rewrite `LinalgOpType` as a vector operation. - template - MatmulCodegenStrategy &vectorize() { - transformation_sequence.emplace_back(new Vectorize()); - return *this; - } - /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector - /// operation. - template - MatmulCodegenStrategy &vectorizeIf(bool b) { - return b ? vectorize() : *this; - return *this; - } - /// Configure the post staged-patterns late vector transformations. - MatmulCodegenStrategy &setVectorTransformsOptions( - mlir::vector::VectorTransformsOptions options) { - vector_transforms_options = options; - return *this; - } - /// Configure the post staged-patterns late vector.transfer to scf conversion. - MatmulCodegenStrategy &setVectorTransferToSCFOptions( - mlir::VectorTransferToSCFOptions options) { - vector_to_scf_options = options; - return *this; - } - - /// Apply the transformation patterns in sequence with cleanup transformations - /// interleaved. - void transform(mlir::FuncOp func) const; - - private: - mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const; - - mlir::vector::VectorTransformsOptions vector_transforms_options; - mlir::VectorTransferToSCFOptions vector_to_scf_options; - llvm::SmallVector, 4> transformation_sequence; -}; - -} // namespace mlir_strategy -} // namespace cpu -} // namespace xla - -#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index 52c26d24fe7..a383b4a4a00 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -52,12 +52,6 @@ class TargetMachineFeatures { virtual int vector_register_num_elements(const llvm::Function& function, PrimitiveType type) const = 0; - // Return the number of vector registers. We need to pass in - // "function" since llvm functions can contain annotations for specializing - // them to specific micro-architectures (though currently XLA does not use - // this functionality). - virtual int vector_register_count(const llvm::Function& function) const = 0; - // Returns the minimum alignment for a buffer of size size_bytes. 
virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0; @@ -90,12 +84,6 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures { (primitive_util::BitWidth(type) / 8); } - int vector_register_count(const llvm::Function& function) const override { - llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); - return static_cast(tti->getNumberOfRegisters( - tti->getRegisterClassForType(/*Vector=*/true))); - } - int64 minimum_alignment_for_allocation(int64 size_bytes) const override; private: diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h index fbbd0d2233d..ffc6927cbe1 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h @@ -44,10 +44,6 @@ class TargetMachineFeaturesWithFakeAlignmentLogic LOG(FATAL) << "Unexpected call to " << __func__; } - int vector_register_count(const llvm::Function& function) const override { - LOG(FATAL) << "Unexpected call to " << __func__; - } - int64 minimum_alignment_for_allocation(int64 size_bytes) const override { return fake_alignment_logic_(size_bytes); } From d93971b09e699f646f773552d53cf3caad28bdea Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Tue, 7 Jul 2020 21:35:40 -0700 Subject: [PATCH 22/88] Update minimize_loss_test to not rely on Keras. Also update the other related tests to use the extracted utils in strategy_test_lib.py PiperOrigin-RevId: 320122161 Change-Id: I0a8f66d19b8f6cf32978f3386d7cbbcbe3dc4b84 --- tensorflow/python/distribute/BUILD | 4 +-- .../python/distribute/minimize_loss_test.py | 27 ++++++++++++------- .../parameter_server_strategy_test.py | 8 ++---- .../python/distribute/single_loss_example.py | 13 +++------ .../python/distribute/strategy_test_lib.py | 26 ++++++++++++------ 5 files changed, 42 insertions(+), 36 deletions(-) diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 89ce2384a3b..2fd1fc93cd3 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -1333,6 +1333,7 @@ distribute_py_test( ":mirrored_strategy", ":single_loss_example", ":strategy_combinations", + ":strategy_test_lib", "//tensorflow/python:control_flow_ops", "//tensorflow/python:control_flow_v2_toggles", "//tensorflow/python:framework_ops", @@ -1342,8 +1343,6 @@ distribute_py_test( "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/eager:context", "//tensorflow/python/eager:test", - "//tensorflow/python/keras/layers", - "//tensorflow/python/keras/optimizer_v2", "//tensorflow/python/ops/losses", "//third_party/py/numpy", "@absl_py//absl/testing:parameterized", @@ -1355,6 +1354,7 @@ py_library( srcs = ["single_loss_example.py"], deps = [ ":step_fn", + ":strategy_test_lib", "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", "//tensorflow/python:layers", diff --git a/tensorflow/python/distribute/minimize_loss_test.py b/tensorflow/python/distribute/minimize_loss_test.py index c9df971783c..434458182b9 100644 --- a/tensorflow/python/distribute/minimize_loss_test.py +++ b/tensorflow/python/distribute/minimize_loss_test.py @@ -20,22 +20,24 @@ from __future__ import print_function from absl.testing import parameterized import numpy + from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import combinations from tensorflow.python.distribute import reduce_util from tensorflow.python.distribute import 
strategy_combinations +from tensorflow.python.distribute import strategy_test_lib from tensorflow.python.distribute.single_loss_example import batchnorm_example from tensorflow.python.distribute.single_loss_example import minimize_loss_example from tensorflow.python.eager import context from tensorflow.python.eager import test from tensorflow.python.framework import constant_op +from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops -from tensorflow.python.keras.layers import core -from tensorflow.python.keras.optimizer_v2 import optimizer_v2 from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import control_flow_v2_toggles from tensorflow.python.ops import math_ops +from tensorflow.python.ops import nn_ops from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variables as variables_lib from tensorflow.python.ops.losses import losses_impl @@ -208,7 +210,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def get_expected_variables(num_parameter_devices): name = optimizer._name - if isinstance(optimizer, optimizer_v2.OptimizerV2): + if strategy_test_lib.is_optimizer_v2_instance(optimizer): variables = VAR_MAP_V2[name] else: variables = VAR_MAP_V1[name] @@ -349,7 +351,7 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): optimizer = optimizer_fn() # GradientDescent with 0.2 learning rate - if isinstance(optimizer, optimizer_v2.OptimizerV2): + if strategy_test_lib.is_optimizer_v2_instance(optimizer): return optimizer.minimize(loss_fn, [w]) else: if use_callable_loss: @@ -426,7 +428,11 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): return dataset.batch(batch_size=1, drop_remainder=True) optimizer = optimizer_fn() - layer = core.Dense(1, use_bias=True) + kernel = strategy_test_lib.create_variable_like_keras_layer( + "kernel", (1, 1), dtypes.float32) + bias = strategy_test_lib.create_variable_like_keras_layer( + "bias", (1,), dtypes.float32) + # layer = core.Dense(1, use_bias=True) key1 = "foo" value1 = "bar" @@ -434,12 +440,13 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): def model_fn(output_context, x): """A very simple model written by the user.""" def loss_fn(): - y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) + y = array_ops.reshape(nn_ops.bias_add( + math_ops.matmul(x, kernel), bias), []) - constant_op.constant(1.) 
return y * y - if isinstance(optimizer, optimizer_v2.OptimizerV2): + if strategy_test_lib.is_optimizer_v2_instance(optimizer): train_op = optimizer.minimize( - loss_fn, lambda: layer.trainable_variables) + loss_fn, lambda: [kernel, bias]) else: train_op = optimizer.minimize(loss_fn) loss = loss_fn() @@ -508,8 +515,8 @@ class MinimizeLossStepTest(test.TestCase, parameterized.TestCase): for _ in range(5): _, loss = run_step() losses.append(loss) - weights.append(self.evaluate(layer.kernel)) - biases.append(self.evaluate(layer.bias)) + weights.append(self.evaluate(kernel)) + biases.append(self.evaluate(bias)) loss_is_not_increasing = all(y <= x for x, y in zip(losses, losses[1:])) self.assertTrue(loss_is_not_increasing) diff --git a/tensorflow/python/distribute/parameter_server_strategy_test.py b/tensorflow/python/distribute/parameter_server_strategy_test.py index 494431571d4..a68183adbaa 100644 --- a/tensorflow/python/distribute/parameter_server_strategy_test.py +++ b/tensorflow/python/distribute/parameter_server_strategy_test.py @@ -19,7 +19,6 @@ from __future__ import division from __future__ import print_function import copy -import functools import threading from absl.testing import parameterized @@ -51,7 +50,6 @@ from tensorflow.python.framework import tensor_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import control_flow_ops from tensorflow.python.ops import gradients -from tensorflow.python.ops import init_ops_v2 from tensorflow.python.ops import math_ops from tensorflow.python.ops import partitioned_variables from tensorflow.python.ops import resource_variable_ops @@ -452,10 +450,8 @@ class ParameterServerStrategyTestBase( self.cached_session(target=master_target, config=sess_config) as sess, \ d.scope(): - initializer = functools.partial( - init_ops_v2.GlorotUniform(), (1, 1), dtype=dtypes.float32) - kernel = variables.Variable( - initial_value=initializer, name='kernel', trainable=True) + kernel = strategy_test_lib.create_variable_like_keras_layer( + 'kernel', (1, 1), dtypes.float32,) def loss_fn(x): y = array_ops.reshape( diff --git a/tensorflow/python/distribute/single_loss_example.py b/tensorflow/python/distribute/single_loss_example.py index 289e27d8084..e7fb7c92cb5 100644 --- a/tensorflow/python/distribute/single_loss_example.py +++ b/tensorflow/python/distribute/single_loss_example.py @@ -20,13 +20,13 @@ from __future__ import print_function from tensorflow.python.data.ops import dataset_ops from tensorflow.python.distribute import step_fn +from tensorflow.python.distribute import strategy_test_lib from tensorflow.python.framework import constant_op from tensorflow.python.framework import ops from tensorflow.python.layers import core from tensorflow.python.layers import normalization from tensorflow.python.ops import array_ops from tensorflow.python.ops import math_ops -from tensorflow.python.util import tf_inspect def single_loss_example(optimizer_fn, distribution, use_bias=False, @@ -69,7 +69,7 @@ def minimize_loss_example(optimizer, use_bias=False, use_callable_loss=True): y = array_ops.reshape(layer(x), []) - constant_op.constant(1.) return y * y - if _is_optimizer_v2_instance(optimizer): + if strategy_test_lib.is_optimizer_v2_instance(optimizer): return optimizer.minimize(loss_fn, lambda: layer.trainable_variables) elif use_callable_loss: return optimizer.minimize(loss_fn) @@ -112,17 +112,10 @@ def batchnorm_example(optimizer_fn, # `x` and `y` will be fetched by the gradient computation, but not `loss`. 
return loss - if _is_optimizer_v2_instance(optimizer): + if strategy_test_lib.is_optimizer_v2_instance(optimizer): return optimizer.minimize(loss_fn, lambda: layer.trainable_variables) # Callable loss. return optimizer.minimize(loss_fn) return model_fn, dataset_fn, batchnorm - - -def _is_optimizer_v2_instance(optimizer): - # For a optimizer instance, the v2 implementation has var_list as a required - # argument. - arg_spec = tf_inspect.getfullargspec(optimizer.minimize) - return 'var_list' in arg_spec.args[:-len(arg_spec.defaults)] diff --git a/tensorflow/python/distribute/strategy_test_lib.py b/tensorflow/python/distribute/strategy_test_lib.py index b8ed0f26ae5..06913db5c72 100644 --- a/tensorflow/python/distribute/strategy_test_lib.py +++ b/tensorflow/python/distribute/strategy_test_lib.py @@ -52,6 +52,7 @@ from tensorflow.python.platform import gfile from tensorflow.python.training import optimizer from tensorflow.python.training import training_util from tensorflow.python.util import nest +from tensorflow.python.util import tf_inspect class _TestException(Exception): @@ -113,18 +114,27 @@ def _events_from_logdir(test_case, logdir): return result +def create_variable_like_keras_layer(name, shape, dtype): + """Utitlity for create variables that works like variable in keras layer.""" + initializer = functools.partial( + init_ops_v2.GlorotUniform(), shape, dtype=dtype) + return variables.Variable( + initial_value=initializer, name=name, trainable=True) + + +def is_optimizer_v2_instance(optimizer_obj): + # For a optimizer instance, the v2 implementation has var_list as a required + # argument. + arg_spec = tf_inspect.getfullargspec(optimizer_obj.minimize) + return "var_list" in arg_spec.args[:-len(arg_spec.defaults)] + + class DistributionTestBase(test.TestCase): """Some tests that should work with any DistributionStrategy.""" - def _create_variable_like_keras_dense_layer(self, name, shape, dtype): - initializer = functools.partial( - init_ops_v2.GlorotUniform(), shape, dtype=dtype) - return variables.Variable( - initial_value=initializer, name=name, trainable=True) - def _test_minimize_loss_eager(self, d): with d.scope(): - kernel = self._create_variable_like_keras_dense_layer( + kernel = create_variable_like_keras_layer( name="kernel", shape=(1, 1), dtype=dtypes.float32) def loss(x): y = array_ops.reshape( @@ -182,7 +192,7 @@ class DistributionTestBase(test.TestCase): ops.Graph().as_default(), \ self.cached_session(config=config) as sess, \ d.scope(): - kernel = self._create_variable_like_keras_dense_layer( + kernel = create_variable_like_keras_layer( name="kernel", shape=(1, 1), dtype=dtypes.float32) def loss(x): From c7b79a8c2325e17c9d6ac900b297d878c62996e4 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 7 Jul 2020 22:33:01 -0700 Subject: [PATCH 23/88] Adding TFConcreteFunction class for ConcreteFunction reloading in the SavedModel C API. 
PiperOrigin-RevId: 320127741 Change-Id: I002c77dd23dc8d85d1b088e2dbe6dfc8088d2b77 --- .../c/experimental/saved_model/core/BUILD | 50 +--- .../saved_model/core/concrete_function.cc | 32 +++ .../saved_model/core/concrete_function.h | 17 +- .../saved_model/core/revived_types/BUILD | 21 -- .../revived_types/tf_concrete_function.cc | 87 ------ .../core/revived_types/tf_concrete_function.h | 87 ------ .../saved_model/core/saved_model_utils.cc | 137 --------- .../saved_model/core/saved_model_utils.h | 9 - .../core/tf_concrete_function_loading_test.cc | 271 ------------------ .../core/tf_concrete_function_test_protos.cc | 212 -------------- .../core/tf_concrete_function_test_protos.h | 50 ---- .../c/experimental/saved_model/internal/BUILD | 8 +- .../saved_model/internal/concrete_function.cc | 10 +- .../saved_model/public/concrete_function.h | 2 +- 14 files changed, 52 insertions(+), 941 deletions(-) create mode 100644 tensorflow/c/experimental/saved_model/core/concrete_function.cc delete mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc delete mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h delete mode 100644 tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc delete mode 100644 tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc delete mode 100644 tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 34d91679dad..5452907f3e8 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -19,6 +19,9 @@ package( cc_library( name = "concrete_function", + srcs = [ + "concrete_function.cc", + ], hdrs = [ "concrete_function.h", ], @@ -26,6 +29,7 @@ cc_library( ":function_metadata", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:protos_all_cc", ], ) @@ -56,13 +60,10 @@ cc_library( "saved_model_utils.h", ], deps = [ - ":function_metadata", "//tensorflow/c:tf_tensor_internal", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", - "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", - "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], ) @@ -90,18 +91,6 @@ cc_library( ], ) -cc_library( - name = "tf_concrete_function_test_protos", - testonly = True, - srcs = ["tf_concrete_function_test_protos.cc"], - hdrs = ["tf_concrete_function_test_protos.h"], - deps = [ - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "@com_google_absl//absl/strings", - ], -) - cc_library( name = "tf_saved_model_impl", srcs = [ @@ -125,16 +114,12 @@ cc_library( "saved_model_api.h", ], visibility = ["//tensorflow/python:__pkg__"], - deps = [ - "//tensorflow/c/eager:immediate_execution_operation", - "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/core:lib", - ], ) filegroup( name = "mobile_srcs_only_runtime", srcs = [ + "concrete_function.cc", "concrete_function.h", "function_metadata.h", "saved_model_api.h", @@ -187,28 +172,3 @@ tf_cc_test( "//tensorflow/core/common_runtime/eager:core", ], ) - -tf_cc_test( - name = "tf_concrete_function_loading_test", - srcs = [ - 
"tf_concrete_function_loading_test.cc", - ], - deps = [ - ":saved_model_utils", - ":test_utils", - ":tf_concrete_function_test_protos", - "//tensorflow/c:tensor_interface", - "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/c/experimental/saved_model/core/revived_types:constant", - "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", - "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", - "//tensorflow/core:framework", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core:test", - "//tensorflow/core:test_main", - "//tensorflow/core/common_runtime:core_cpu_lib", - "//tensorflow/core/common_runtime/eager:context", - "//tensorflow/core/common_runtime/eager:core", - ], -) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.cc b/tensorflow/c/experimental/saved_model/core/concrete_function.cc new file mode 100644 index 00000000000..41bae4352fc --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.cc @@ -0,0 +1,32 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" + +namespace tensorflow { + +const std::vector& +ConcreteFunction::GetCaptures() const { + return captures_; +} + +const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const { + return metadata_; +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 2cc627bcf27..22535641ef5 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ -#include #include #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" +#include "tensorflow/core/framework/function.pb.h" namespace tensorflow { @@ -35,14 +35,19 @@ namespace tensorflow { // and have only a single implementation. class ConcreteFunction { public: - virtual ~ConcreteFunction() = default; + virtual ~ConcreteFunction() = 0; // This method returns the "Call" Op used to execute the function. 
- virtual Status GetCallOp(ImmediateOpPtr* out) = 0; + virtual ImmediateExecutionOperation* GetCallOp() = 0; - virtual const std::vector& GetCaptures() - const = 0; - virtual const FunctionMetadata& GetFunctionMetadata() const = 0; + const std::vector& GetCaptures() + const; + const FunctionMetadata& GetFunctionMetadata() const; + + private: + FunctionMetadata metadata_; + std::vector captures_; + FunctionDef* function_; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 8bb15674db0..84fad2ea8f6 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -58,24 +58,3 @@ cc_library( "//tensorflow/c/eager:immediate_execution_tensor_handle", ], ) - -cc_library( - name = "tf_concrete_function", - srcs = [ - "tf_concrete_function.cc", - ], - hdrs = [ - "tf_concrete_function.h", - ], - deps = [ - ":tensorhandle_convertible", - "//tensorflow/c/eager:immediate_execution_context", - "//tensorflow/c/eager:immediate_execution_operation", - "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/c/experimental/saved_model/core:concrete_function", - "//tensorflow/c/experimental/saved_model/core:function_metadata", - "//tensorflow/core:lib", - "//tensorflow/core:protos_all_cc", - "//tensorflow/core/common_runtime/eager:context", - ], -) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc deleted file mode 100644 index aa6f0e7205e..00000000000 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" - -#include -#include - -#include "tensorflow/c/eager/immediate_execution_operation.h" -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" -#include "tensorflow/core/common_runtime/eager/context.h" -#include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/platform/errors.h" -#include "tensorflow/core/platform/logging.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" -#include "tensorflow/core/protobuf/struct.pb.h" - -namespace tensorflow { - -TFConcreteFunction::TFConcreteFunction( - const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx) - : name_(name), - captures_(std::move(captures)), - metadata_(std::move(metadata)), - ctx_(ctx) {} - -TFConcreteFunction::~TFConcreteFunction() { - Status status = ctx_->RemoveFunction(name_); - if (!status.ok()) { - LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " - << status.error_message(); - } -} - -Status TFConcreteFunction::Create( - const FunctionDef* function_def, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx, - std::unique_ptr* out) { - TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); - out->reset(new TFConcreteFunction(function_def->signature().name(), - std::move(captures), std::move(metadata), - ctx)); - return Status(); -} - -const std::vector& -TFConcreteFunction::GetCaptures() const { - return captures_; -} - -const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { - return metadata_; -} - -Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) { - out->reset(ctx_->CreateOperation()); - // In eager mode, TF2 python executes functions by constructing an op with - // the name of the functiondef: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 - // In graph mode, we create a PartitionedCallOp instead: - // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 - - // TODO(bmzhao): After discussing with Allen, we should execute this via a - // PartitionedCallOp for compatibility with "tooling that assumes functions in - // graphs are PartitionedCallOps". - TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); - return Status(); -} - -} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h deleted file mode 100644 index 71c8322414d..00000000000 --- a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h +++ /dev/null @@ -1,87 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
-See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ -#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ - -#include -#include -#include -#include -#include - -#include "tensorflow/c/eager/immediate_execution_context.h" -#include "tensorflow/c/eager/immediate_execution_operation.h" -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" -#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" -#include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" - -namespace tensorflow { - -// TF Eager Runtime-based implementation of a "ConcreteFunction" loaded from a -// saved model. -class TFConcreteFunction : public ConcreteFunction { - public: - // Factory function for creating a TFConcreteFunction. - // - // Params: - // function_def - The function_def associated with the created - // TFConcreteFunction. TFConcreteFunction will register this - // function_def with `ctx` on creation, and de-register it on - // destruction. function_def must be non-null, but - // otherwise has no lifetime requirements. - // captures - The captured TensorHandles associated with this - // TFConcreteFunction. - // metadata - The FunctionMetadata associated with this TFConcreteFunction. - // ctx - A handle to the Tensorflow runtime. This MUST be non-null and - // outlive TFConcreteFunction. - // out - The output TFConcreteFunction. - static Status Create(const FunctionDef* function_def, - std::vector captures, - FunctionMetadata metadata, - ImmediateExecutionContext* ctx, - std::unique_ptr* out); - - // This method returns the "Call" Op used to execute the function. - Status GetCallOp(ImmediateOpPtr* out) override; - - const std::vector& GetCaptures() - const override; - - const FunctionMetadata& GetFunctionMetadata() const override; - - ~TFConcreteFunction() override; - - private: - TFConcreteFunction(const std::string& name, - std::vector captures, - FunctionMetadata metadata, ImmediateExecutionContext* ctx); - - TFConcreteFunction(const TFConcreteFunction&) = delete; - TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; - - // Name of the FunctionDef corresponding to this TFConcreteFunction - std::string name_; - std::vector captures_; - FunctionMetadata metadata_; - ImmediateExecutionContext* ctx_; -}; - -} // namespace tensorflow - -#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 4b1d7672ed6..196420eb537 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -17,125 +17,14 @@ limitations under the License. 
#include -#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/framework/tensor.pb.h" -#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" -#include "tensorflow/core/protobuf/struct.pb.h" namespace tensorflow { namespace internal { -namespace { - -// This returns the size of `tf.nest.flatten(value)`, on values that are -// used in tf.function's input_signatures. -int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) { - // This follows the logic from - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775 - switch (value.kind_case()) { - case StructuredValue::kDictValue: { - const DictValue& dict = value.dict_value(); - int size = 0; - for (const auto& field : dict.fields()) { - size += FlattenedSize(field.second, status); - } - return size; - } - case StructuredValue::kTupleValue: { - const TupleValue& tuple = value.tuple_value(); - int size = 0; - for (const StructuredValue& value : tuple.values()) { - size += FlattenedSize(value, status); - } - return size; - } - case StructuredValue::kListValue: { - const ListValue& list = value.list_value(); - int size = 0; - for (const StructuredValue& value : list.values()) { - size += FlattenedSize(value, status); - } - return size; - } - case StructuredValue::kTensorSpecValue: { - return 1; - } - case StructuredValue::kNoneValue: { - // Base case: do nothing. - // This arises, for example, as the top-level object of an output - // signature when there are no return values. - return 0; - } - default: { - status->Update(errors::Internal("Unhandled structured value kind ", - value.kind_case())); - return 0; - } - } -} - -// Perform some basic sanity checks on SavedConcreteFunction's input and -// output signatures with respect to the corresponding FunctionDef's input -// and output args. -Status ValidateSavedFunctionCompatibleWithFunctionDef( - const SavedConcreteFunction& saved_concrete_function, - const FunctionDef* function_def) { - // tf.functions go through many transformations before becoming FunctionDefs - // 1. flatten user-provided inputs: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675 - // 2. convert user-provided inputs to tensors: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688 - // 3. filter any non-tensor, non-variable inputs: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841 - // 4. concatenate any captured inputs: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912 - - // Since our API is limited to tf.functions annotated with input signatures, - // conditions 2 and 3 are trivially satisfied. 
- // We need to ensure that: - // flatten(input_signature).size() + captures.size() = fdef.signature().size() - // A concrete function's serialized "canonicalized_input_signature" comes - // from encoding its "structured_input_signature" field: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71 - // The "structured_input_signature" is guaranteed to be a tuple of the python - // args, kwargs that correspond to the tf.function: - // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979 - - const std::string& name = function_def->signature().name(); - const StructuredValue& input_signature = - saved_concrete_function.canonicalized_input_signature(); - Status status; - int input_signature_size = FlattenedSize(input_signature, &status); - TF_RETURN_IF_ERROR(status); - if (input_signature_size + saved_concrete_function.bound_inputs_size() != - function_def->signature().input_arg_size()) { - return errors::FailedPrecondition( - "FunctionDef ", name, " has ", - function_def->signature().input_arg_size(), - " inputs, but the SavedConcreteFunction has ", input_signature_size, - " flattened user inputs and ", - saved_concrete_function.bound_inputs_size(), " captured inputs."); - } - - const StructuredValue& output_signature = - saved_concrete_function.output_signature(); - int output_signature_size = FlattenedSize(output_signature, &status); - TF_RETURN_IF_ERROR(status); - if (output_signature_size != function_def->signature().output_arg_size()) { - return errors::FailedPrecondition( - "FunctionDef ", name, " has ", - function_def->signature().output_arg_size(), - " outputs, but the SavedConcreteFunction has ", output_signature_size, - " flattened outputs."); - } - - return status; -} - -} // namespace Status TensorProtoToConstant(ImmediateExecutionContext* ctx, const TensorProto& proto, @@ -165,31 +54,5 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, return Status(); } -Status LoadTFConcreteFunction( - const SavedConcreteFunction& saved_concrete_function, - const FunctionDef* function_def, - const std::unordered_map>& - captured_objects, - ImmediateExecutionContext* ctx, std::unique_ptr* out) { - TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef( - saved_concrete_function, function_def)); - - // Copy over captures - std::vector captures; - captures.reserve(saved_concrete_function.bound_inputs_size()); - for (int bound_input : saved_concrete_function.bound_inputs()) { - auto iter = captured_objects.find(bound_input); - if (iter == captured_objects.end()) { - return errors::FailedPrecondition("Failed to find bound_input ", - bound_input, - " for SavedConcreteFunction"); - } - captures.push_back(iter->second->handle()); - } - - return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx, - out); -} - } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index 89a959a89d4..ab1531709e4 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -21,7 +21,6 @@ limitations under the License. 
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" @@ -44,14 +43,6 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, const SavedVariable& variable, std::unique_ptr* output); -// Creates a TFConcreteFunction from a SavedConcreteFunction. -Status LoadTFConcreteFunction( - const SavedConcreteFunction& saved_concrete_function, - const FunctionDef* function_def, - const std::unordered_map>& - captured_objects, - ImmediateExecutionContext* ctx, std::unique_ptr* out); - } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc deleted file mode 100644 index 05fbac13077..00000000000 --- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc +++ /dev/null @@ -1,271 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. 
-==============================================================================*/ - -#include -#include - -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" -#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" -#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" -#include "tensorflow/c/experimental/saved_model/core/test_utils.h" -#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" -#include "tensorflow/core/framework/function.pb.h" -#include "tensorflow/core/framework/op_def.pb.h" -#include "tensorflow/core/framework/types.pb.h" -#include "tensorflow/core/lib/core/status_test_util.h" -#include "tensorflow/core/platform/status.h" -#include "tensorflow/core/platform/test.h" -#include "tensorflow/core/protobuf/error_codes.pb.h" -#include "tensorflow/core/protobuf/saved_object_graph.pb.h" - -namespace tensorflow { -namespace { - -class SavedConcreteFunctionLoadingTest : public ::testing::Test { - public: - SavedConcreteFunctionLoadingTest() - : device_mgr_(testing::CreateTestingDeviceMgr()), - ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} - - EagerContext* context() { return ctx_.get(); } - - private: - std::unique_ptr device_mgr_; - EagerContextPtr ctx_; -}; - -class DummyCapture : public TensorHandleConvertible { - public: - DummyCapture(ImmediateExecutionContext* ctx, int8 value) - : TensorHandleConvertible( - testing::CreateTensorHandle(ctx, DT_FLOAT, {2, 4}, value)) {} -}; - -FunctionDef FuncDefWithNumInputsOutputs(int num_inputs, int num_outputs) { - FunctionDef func; - OpDef* signature = func.mutable_signature(); - for (int i = 0; i < num_inputs; ++i) { - signature->add_input_arg(); - } - for (int i = 0; i < num_outputs; ++i) { - signature->add_output_arg(); - } - return func; -} - -// A SavedConcreteFunction whose canonicalized input signature -// has less inputs than its corresponding FunctionDef should cause an error. -TEST_F(SavedConcreteFunctionLoadingTest, TooFewInputsInSavedConcreteFunction) { - // `saved` has 1 input - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::SingleArgInputSignature(); - *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); - - // `func` has 2 inputs - FunctionDef func = FuncDefWithNumInputsOutputs(2, 0); - - std::unique_ptr result; - Status status = - internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); -} - -// A SavedConcreteFunction whose canonicalized input signature length + -// captures is less than its corresponding FunctionDef should cause an error. 
-TEST_F(SavedConcreteFunctionLoadingTest, - TooFewInputsWithCapturesInSavedConcreteFunction) { - // `saved` has 1 input, and 1 capture, for a total of 2 inputs - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::SingleArgInputSignature(); - *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); - saved.add_bound_inputs(5); - - // `func` has 3 inputs - FunctionDef func = FuncDefWithNumInputsOutputs(3, 0); - - std::unordered_map> captures; - captures[5] = std::make_unique(context(), 10); - - std::unique_ptr result; - Status status = internal::LoadTFConcreteFunction(saved, &func, captures, - context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); -} - -// A SavedConcreteFunction whose canonicalized input signature -// has more inputs than its corresponding FunctionDef should cause an error. -TEST_F(SavedConcreteFunctionLoadingTest, TooManyInputsInSavedConcreteFunction) { - // `saved` has 3 inputs - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::ThreeArgInputSignature(); - *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); - - // `func` has 2 inputs - FunctionDef func = FuncDefWithNumInputsOutputs(2, 0); - - std::unique_ptr result; - Status status = - internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); -} - -// A SavedConcreteFunction whose canonicalized input signature -// has the same number of inputs than its corresponding FunctionDef, but has -// additional captures should cause an error. -TEST_F(SavedConcreteFunctionLoadingTest, - TooManyInputsWithCaptureInSavedConcreteFunction) { - // `saved` has 3 inputs, and 1 capture, for a total of 4 inputs. - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::ThreeArgInputSignature(); - *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); - saved.add_bound_inputs(5); - - // `func` has 3 inputs. - FunctionDef func = FuncDefWithNumInputsOutputs(3, 0); - - std::unordered_map> captures; - captures[5] = std::make_unique(context(), 10); - - std::unique_ptr result; - Status status = internal::LoadTFConcreteFunction(saved, &func, captures, - context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); -} - -// A SavedConcreteFunction whose capture refers to an index not in the capture -// map should cause an error. -TEST_F(SavedConcreteFunctionLoadingTest, ImproperCaptureIndex) { - // `saved` has 3 inputs, 1 capture, for a total of 4 inputs - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::ThreeArgInputSignature(); - *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); - // Capture is at index "10" - saved.add_bound_inputs(10); - - // `func` has 4 inputs - FunctionDef func = FuncDefWithNumInputsOutputs(4, 0); - - // `captures` only has a capture for index 5 - std::unordered_map> captures; - captures[5] = std::make_unique(context(), 10); - - std::unique_ptr result; - Status status = internal::LoadTFConcreteFunction(saved, &func, captures, - context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); -} - -// A SavedConcreteFunction whose outputs are fewer than its corresponding -// functiondef should cause an error. 
-TEST_F(SavedConcreteFunctionLoadingTest, TooFewOutputsInSavedConcreteFunction) { - // `saved` has 0 inputs, 1 output - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::ZeroArgInputSignature(); - *saved.mutable_output_signature() = testing::SingleReturnOutputSignature(); - - // `func` has 0 inputs, 2 outputs - FunctionDef func = FuncDefWithNumInputsOutputs(0, 2); - - std::unique_ptr result; - Status status = - internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); -} - -// A SavedConcreteFunction whose outputs exceed its corresponding functiondef -// should cause an error. -TEST_F(SavedConcreteFunctionLoadingTest, - TooManyOutputsInSavedConcreteFunction) { - // `saved` has 1 input, 3 outputs - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::SingleArgInputSignature(); - *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature(); - - // `func` has 1 input, 2 outputs - FunctionDef func = FuncDefWithNumInputsOutputs(1, 2); - - std::unique_ptr result; - Status status = - internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) - << status.error_message(); -} - -// A SavedConcreteFunction whose (inputs + captures) = functiondef inputs, -// and whose outputs = functiondef outputs should successfully load. -TEST_F(SavedConcreteFunctionLoadingTest, SuccessfulLoad) { - // `saved` has 1 input, 2 captures, 3 outputs - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::SingleArgInputSignature(); - *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature(); - saved.add_bound_inputs(2); - saved.add_bound_inputs(5); - - // `func` has 3 inputs, 3 outputs - FunctionDef func = FuncDefWithNumInputsOutputs(3, 3); - - std::unordered_map> captures; - captures[2] = std::make_unique(context(), 1); - captures[5] = std::make_unique(context(), 10); - - std::unique_ptr result; - Status status = internal::LoadTFConcreteFunction(saved, &func, captures, - context(), &result); - TF_EXPECT_OK(status) << status.error_message(); -} - -// A TFConcreteFunction should register functiondefs on creation, and -// remove them upon deletion. -TEST_F(SavedConcreteFunctionLoadingTest, RegistersAndRemovesFunctionDefs) { - std::string func_name = "FooBarBazWombatFunction"; - - SavedConcreteFunction saved; - *saved.mutable_canonicalized_input_signature() = - testing::ZeroArgInputSignature(); - *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); - FunctionDef func = FuncDefWithNumInputsOutputs(0, 0); - *func.mutable_signature()->mutable_name() = func_name; - - { - std::unique_ptr result; - Status status = - internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); - TF_EXPECT_OK(status) << status.error_message(); - // The function should be registered with context. - EXPECT_TRUE(context()->FindFunctionByName(func_name)); - } - - // After `result's` destructor runs, the function should no longer be - // registered with context. 
- EXPECT_FALSE(context()->FindFunctionByName(func_name)); -} - -} // namespace -} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc deleted file mode 100644 index dc69cf6203c..00000000000 --- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc +++ /dev/null @@ -1,212 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" - -#include - -#include "absl/strings/string_view.h" -#include "tensorflow/core/platform/protobuf.h" -#include "tensorflow/core/protobuf/struct.pb.h" - -namespace tensorflow { -namespace testing { -namespace { - -constexpr absl::string_view kZeroArgInputSignatureTextProto = R"( -tuple_value: { - values: { - tuple_value: { - } - } - values: { - dict_value: { - } - } -} -)"; - -constexpr absl::string_view kSingleArgInputSignatureTextProto = R"( -tuple_value: { - values: { - tuple_value: { - values: { - tensor_spec_value: { - name : "x" - shape: { - dim: { - size: 1 - } - dim: { - size: 10 - } - } - dtype: DT_FLOAT - } - } - } - } - values: { - dict_value: { - } - } -} -)"; - -constexpr absl::string_view kThreeArgInputSignatureTextProto = R"( -tuple_value: { - values: { - tuple_value: { - values: { - tensor_spec_value: { - name : "x" - shape: { - dim: { - size: 1 - } - } - dtype: DT_FLOAT - } - } - values: { - tensor_spec_value: { - name : "y" - shape: { - dim: { - size: 1 - } - } - dtype: DT_FLOAT - } - } - values: { - tensor_spec_value: { - name : "z" - shape: { - dim: { - size: 1 - } - } - dtype: DT_FLOAT - } - } - } - } - values: { - dict_value: { - } - } -} - -)"; - -constexpr absl::string_view kZeroReturnOutputSignatureTextProto = R"( -none_value: {} -)"; - -constexpr absl::string_view kSingleReturnOutputSignatureTextProto = R"( -tensor_spec_value: { - shape: { - dim: { - size: 1 - } - } - dtype: DT_FLOAT -} -)"; - -constexpr absl::string_view kThreeReturnOutputSignatureTextProto = R"( -tuple_value: { - values: { - dict_value: { - fields: { - key : "a" - value: { - tensor_spec_value: { - name : "0/a" - shape: { - dim: { - size: 1 - } - } - dtype: DT_FLOAT - } - } - } - fields: { - key : "b" - value: { - tensor_spec_value: { - name : "0/b" - shape: { - dim: { - size: 1 - } - } - dtype: DT_FLOAT - } - } - } - } - } - values: { - tensor_spec_value: { - name : "1" - shape: { - dim: { - size: 1 - } - } - dtype: DT_FLOAT - } - } -} -)"; - -StructuredValue ParseStructuredValue(absl::string_view text_proto) { - StructuredValue value; - CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), - &value)); - return value; -} - -} // namespace - -StructuredValue ZeroArgInputSignature() { - return ParseStructuredValue(kZeroArgInputSignatureTextProto); -} - -StructuredValue 
SingleArgInputSignature() { - return ParseStructuredValue(kSingleArgInputSignatureTextProto); -} - -StructuredValue ThreeArgInputSignature() { - return ParseStructuredValue(kThreeArgInputSignatureTextProto); -} - -StructuredValue ZeroReturnOutputSignature() { - return ParseStructuredValue(kZeroReturnOutputSignatureTextProto); -} - -StructuredValue SingleReturnOutputSignature() { - return ParseStructuredValue(kSingleReturnOutputSignatureTextProto); -} - -StructuredValue ThreeReturnOutputSignature() { - return ParseStructuredValue(kThreeReturnOutputSignatureTextProto); -} - -} // namespace testing -} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h deleted file mode 100644 index 8aa7d5694e1..00000000000 --- a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h +++ /dev/null @@ -1,50 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ -#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ - -#include "tensorflow/core/protobuf/struct.pb.h" - -namespace tensorflow { -namespace testing { - -// Returns a StructuredValue corresponding to the serialized InputSignature of a -// tf.function with 0 inputs -StructuredValue ZeroArgInputSignature(); - -// Returns a StructuredValue corresponding to the serialized InputSignature of a -// tf.function with 1 input -StructuredValue SingleArgInputSignature(); - -// Returns a StructuredValue corresponding to the serialized InputSignature of a -// tf.function with 3 inputs -StructuredValue ThreeArgInputSignature(); - -// Returns a StructuredValue corresponding to the serialized OutputSignature of -// a tf.function with no return values -StructuredValue ZeroReturnOutputSignature(); - -// Returns a StructuredValue corresponding to the serialized OutputSignature of -// a tf.function with a single tensor output -StructuredValue SingleReturnOutputSignature(); - -// Returns a StructuredValue corresponding to the serialized OutputSignature of -// a tf.function with three tensor outputs -StructuredValue ThreeReturnOutputSignature(); - -} // namespace testing -} // namespace tensorflow -#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 6be2a02eeb8..888c284bb12 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -41,13 +41,11 @@ cc_library( ":tensorhandle_list", ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", - "//tensorflow/c:tf_status_internal", "//tensorflow/c/eager:c_api", - 
"//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:c_api_internal", "//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:function_metadata", - "//tensorflow/core:lib", ], ) @@ -207,13 +205,9 @@ tf_cc_test( ], deps = [ "//tensorflow/c:tf_status", - "//tensorflow/c:tf_tensor", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", - "//tensorflow/c/eager:c_api_test_util", - "//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", - "//tensorflow/c/experimental/saved_model/public:tensorhandle_list", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index 12d49212a88..dd54416ddf9 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -15,15 +15,12 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" -#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" #include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" -#include "tensorflow/c/tf_status_internal.h" -#include "tensorflow/core/platform/status.h" extern "C" { @@ -37,11 +34,8 @@ const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures()); } -TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, - TF_Status* status) { - tensorflow::ImmediateOpPtr call_op(nullptr); - status->status = tensorflow::unwrap(func)->GetCallOp(&call_op); - return tensorflow::wrap(call_op.release()); +TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { + return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp()); } } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 944ddecea16..2a87214270c 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -41,7 +41,7 @@ TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( // Returns a TFE_Op suitable for executing this function. TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( - TF_ConcreteFunction* func, TF_Status* status); + TF_ConcreteFunction* func); #ifdef __cplusplus } // end extern "C" From 44de16f84ab41883f1de88061f482d9754681478 Mon Sep 17 00:00:00 2001 From: Mihai Maruseac Date: Tue, 7 Jul 2020 22:45:30 -0700 Subject: [PATCH 24/88] Strictify Python targets in tensorflow/python/saved_model/. 
PiperOrigin-RevId: 320128932 Change-Id: I3a12aff52bb2b23025028643bca2756d55870809 --- tensorflow/python/saved_model/BUILD | 127 +++++------------- .../python/saved_model/model_utils/BUILD | 31 ++--- tensorflow/tensorflow.bzl | 4 - 3 files changed, 45 insertions(+), 117 deletions(-) diff --git a/tensorflow/python/saved_model/BUILD b/tensorflow/python/saved_model/BUILD index cf90fe0eeed..1fc6253f763 100644 --- a/tensorflow/python/saved_model/BUILD +++ b/tensorflow/python/saved_model/BUILD @@ -1,7 +1,6 @@ # Description: # TensorFlow SavedModel. -load("//tensorflow:tensorflow.bzl", "py_strict_library") load("//tensorflow:tensorflow.bzl", "cuda_py_test") # buildifier: disable=same-origin-load @@ -16,7 +15,7 @@ package( exports_files(["LICENSE"]) -py_strict_library( +py_library( name = "saved_model", srcs = ["saved_model.py"], srcs_version = "PY2AND3", @@ -38,28 +37,28 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "constants", srcs = ["constants.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:tf_export"], + deps = ["//tensorflow/python:util"], ) -py_strict_library( +py_library( name = "signature_constants", srcs = ["signature_constants.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:tf_export"], + deps = ["//tensorflow/python:util"], ) -py_strict_library( +py_library( name = "tag_constants", srcs = ["tag_constants.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:tf_export"], + deps = ["//tensorflow/python:util"], ) -py_strict_library( +py_library( name = "builder", srcs = [ "builder.py", @@ -68,20 +67,18 @@ py_strict_library( srcs_version = "PY2AND3", deps = [ ":constants", - ":signature_def_utils", ":utils", "//tensorflow/core:protos_all_py", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:saver", - "//tensorflow/python:tf_export", "//tensorflow/python:util", "//tensorflow/python:variables", ], ) -py_strict_library( +py_library( name = "loader", srcs = [ "loader.py", @@ -97,7 +94,6 @@ py_strict_library( "//tensorflow/python:lib", "//tensorflow/python:platform", "//tensorflow/python:saver", - "//tensorflow/python:tf_export", "//tensorflow/python:util", "//tensorflow/python:variables", ], @@ -125,7 +121,7 @@ tf_py_test( ], ) -py_strict_library( +py_library( name = "simple_save", srcs = [ "simple_save.py", @@ -136,13 +132,13 @@ py_strict_library( ":signature_constants", ":signature_def_utils", ":tag_constants", - "//tensorflow/python:framework_ops", - "//tensorflow/python:tf_export", + "//tensorflow/core:protos_all_py", + "//tensorflow/python:lib", "//tensorflow/python:util", ], ) -py_strict_library( +py_library( name = "main_op", srcs = [ "main_op.py", @@ -153,7 +149,6 @@ py_strict_library( "//tensorflow/python:control_flow_ops", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lookup_ops", - "//tensorflow/python:tf_export", "//tensorflow/python:util", "//tensorflow/python:variables", ], @@ -190,7 +185,7 @@ tf_py_test( ], ) -py_strict_library( +py_library( name = "utils", srcs = [ "utils.py", @@ -201,13 +196,10 @@ py_strict_library( ":constants", ":nested_structure_coder", "//tensorflow/core:protos_all_py", - "//tensorflow/python:composite_tensor", "//tensorflow/python:framework_for_generated_wrappers", "//tensorflow/python:lib", "//tensorflow/python:sparse_tensor", - "//tensorflow/python:tf_export", "//tensorflow/python:util", - "//tensorflow/python/eager:context", ], ) @@ -225,7 +217,7 @@ tf_py_test( ], ) 
-py_strict_library( +py_library( name = "signature_def_utils", srcs = [ "signature_def_utils.py", @@ -236,9 +228,7 @@ py_strict_library( ":signature_constants", ":utils", "//tensorflow/core:protos_all_py", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:tf_export", "//tensorflow/python:util", ], ) @@ -273,18 +263,16 @@ tf_py_test( ], ) -py_strict_library( +py_library( name = "signature_serialization", srcs = [ "signature_serialization.py", ], srcs_version = "PY2AND3", deps = [ - ":function_serialization", ":revived_types", ":signature_constants", "//tensorflow/python:framework_ops", - "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_spec", "//tensorflow/python:util", "//tensorflow/python/eager:def_function", @@ -293,7 +281,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "save_context", srcs = [ "save_context.py", @@ -302,7 +290,7 @@ py_strict_library( deps = [], ) -py_strict_library( +py_library( name = "save", srcs = [ "save.py", @@ -326,22 +314,16 @@ py_strict_library( "//tensorflow/python:constant_op", "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", - "//tensorflow/python:error_interpolation", - "//tensorflow/python:errors", "//tensorflow/python:framework", "//tensorflow/python:framework_ops", "//tensorflow/python:lib", "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_util", - "//tensorflow/python:tf_export", "//tensorflow/python:util", - "//tensorflow/python:versions", "//tensorflow/python/eager:context", "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", "//tensorflow/python/training/saving:checkpoint_options", "//tensorflow/python/training/saving:functional_saver", - "//tensorflow/python/training/saving:saveable_object_util", "//tensorflow/python/training/tracking", "//tensorflow/python/training/tracking:base", "//tensorflow/python/training/tracking:graph_view", @@ -367,13 +349,14 @@ tf_py_test( ], ) -py_strict_library( +py_library( name = "load", srcs = [ "load.py", ], srcs_version = "PY2AND3", deps = [ + ":constants", ":function_deserialization", ":load_options", ":load_v1_in_v2", @@ -382,25 +365,17 @@ py_strict_library( ":revived_types", ":utils", "//tensorflow/core:protos_all_py", - "//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:control_flow_ops", - "//tensorflow/python:dtypes", - "//tensorflow/python:errors", "//tensorflow/python:framework_ops", - "//tensorflow/python:gradients", - "//tensorflow/python:lookup_ops", + "//tensorflow/python:init_ops", + "//tensorflow/python:lib", "//tensorflow/python:resource_variable_ops", "//tensorflow/python:tensor_util", - "//tensorflow/python:tf_export", "//tensorflow/python:util", "//tensorflow/python:variables", "//tensorflow/python/distribute:distribute_lib", "//tensorflow/python/distribute:distribute_utils", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:function", - "//tensorflow/python/training/saving:checkpoint_options", - "//tensorflow/python/training/saving:saveable_object_util", + "//tensorflow/python/distribute:values", "//tensorflow/python/training/tracking", "//tensorflow/python/training/tracking:base", "//tensorflow/python/training/tracking:graph_view", @@ -408,26 +383,17 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "load_v1_in_v2", srcs = [ "load_v1_in_v2.py", ], srcs_version = "PY2AND3", deps = [ - ":function_deserialization", ":loader", ":signature_serialization", - 
"//tensorflow/python:array_ops", "//tensorflow/python:constant_op", - "//tensorflow/python:dtypes", - "//tensorflow/python:framework_ops", - "//tensorflow/python:platform", "//tensorflow/python:saver", - "//tensorflow/python:sparse_tensor", - "//tensorflow/python:training_lib", - "//tensorflow/python/eager:context", - "//tensorflow/python/eager:lift_to_graph", "//tensorflow/python/eager:wrap_function", "//tensorflow/python/training/tracking", ], @@ -485,7 +451,7 @@ tf_py_test( ], ) -py_strict_library( +py_library( name = "revived_types", srcs = [ "revived_types.py", @@ -506,7 +472,7 @@ tf_py_test( ], ) -py_strict_library( +py_library( name = "function_serialization", srcs = [ "function_serialization.py", @@ -515,13 +481,12 @@ py_strict_library( deps = [ ":nested_structure_coder", "//tensorflow/core:protos_all_py", - "//tensorflow/python:func_graph", - "//tensorflow/python:util", + "//tensorflow/python/eager:def_function", "//tensorflow/python/eager:function", ], ) -py_strict_library( +py_library( name = "function_deserialization", srcs = [ "function_deserialization.py", @@ -529,42 +494,22 @@ py_strict_library( srcs_version = "PY2AND3", deps = [ ":nested_structure_coder", - "//tensorflow/core:protos_all_py", - "//tensorflow/python:framework_ops", - "//tensorflow/python:func_graph", - "//tensorflow/python:function_def_to_graph", - "//tensorflow/python:op_def_registry", - "//tensorflow/python:platform", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:tensor_spec", - "//tensorflow/python:tf_decorator", - "//tensorflow/python:type_spec", - "//tensorflow/python:util", "//tensorflow/python/eager:def_function", - "//tensorflow/python/eager:function", ], ) -py_strict_library( +py_library( name = "nested_structure_coder", srcs = ["nested_structure_coder.py"], deps = [ "//tensorflow/core:protos_all_py", - "//tensorflow/python:dtypes", - "//tensorflow/python:indexed_slices", - "//tensorflow/python:resource_variable_ops", - "//tensorflow/python:sparse_tensor", + "//tensorflow/python:framework", "//tensorflow/python:tensor_array_ops", - "//tensorflow/python:tensor_shape", - "//tensorflow/python:tensor_spec", - "//tensorflow/python:tensor_util", - "//tensorflow/python:util", "//tensorflow/python/data/ops:dataset_ops", "//tensorflow/python/data/ops:iterator_ops", "//tensorflow/python/data/ops:optional_ops", "//tensorflow/python/distribute:values", - "//tensorflow/python/ops/ragged:ragged_tensor", - "//tensorflow/python/ops/ragged:row_partition", + "//tensorflow/python/ops/ragged", "@six_archive//:six", ], ) @@ -580,25 +525,22 @@ tf_py_test( ], ) -py_strict_library( +py_library( name = "save_options", srcs = ["save_options.py"], deps = [ - "//tensorflow/python:tf_export", - "//tensorflow/python:util", "@six_archive//:six", ], ) -py_strict_library( +py_library( name = "load_options", srcs = ["load_options.py"], deps = [ - "//tensorflow/python:tf_export", ], ) -py_strict_library( +py_library( name = "method_name_updater", srcs = ["method_name_updater.py"], srcs_version = "PY2AND3", @@ -607,7 +549,6 @@ py_strict_library( ":loader", "//tensorflow/python:lib", "//tensorflow/python:platform", - "//tensorflow/python:tf_export", "//tensorflow/python:util", ], ) diff --git a/tensorflow/python/saved_model/model_utils/BUILD b/tensorflow/python/saved_model/model_utils/BUILD index 82a33c8e522..70cc89b1946 100644 --- a/tensorflow/python/saved_model/model_utils/BUILD +++ b/tensorflow/python/saved_model/model_utils/BUILD @@ -15,12 +15,7 @@ # Description: # Keras saving and loading libraries. 
- -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "py_strict_library") - -# buildifier: disable=same-origin-load -load("//tensorflow:tensorflow.bzl", "py_strict_test") +load("//tensorflow:tensorflow.bzl", "py_test") package( default_visibility = ["//tensorflow:__subpackages__"], @@ -29,7 +24,7 @@ package( exports_files(["LICENSE"]) -py_strict_library( +py_library( name = "model_utils", srcs = ["__init__.py"], srcs_version = "PY2AND3", @@ -40,7 +35,7 @@ py_strict_library( ], ) -py_strict_library( +py_library( name = "export_output", srcs = ["export_output.py"], srcs_version = "PY2AND3", @@ -49,11 +44,10 @@ py_strict_library( "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", "//tensorflow/python/saved_model:signature_def_utils", - "@six_archive//:six", ], ) -py_strict_test( +py_test( name = "export_output_test", srcs = ["export_output_test.py"], python_version = "PY3", @@ -67,14 +61,12 @@ py_strict_test( "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", "//tensorflow/python:framework_ops", - "//tensorflow/python:metrics", "//tensorflow/python:sparse_tensor", - "//tensorflow/python/eager:context", "//tensorflow/python/saved_model:signature_constants", ], ) -py_strict_library( +py_library( name = "export_utils", srcs = ["export_utils.py"], srcs_version = "PY2AND3", @@ -86,37 +78,36 @@ py_strict_library( "//tensorflow/python/saved_model:signature_constants", "//tensorflow/python/saved_model:signature_def_utils", "//tensorflow/python/saved_model:tag_constants", - "@six_archive//:six", ], ) -py_strict_test( +py_test( name = "export_test", srcs = ["export_test.py"], python_version = "PY3", srcs_version = "PY2AND3", deps = [ - ":export_output", ":export_utils", - ":mode_keys", "//tensorflow/python:array_ops", "//tensorflow/python:client_testlib", "//tensorflow/python:constant_op", + "//tensorflow/python:control_flow_ops", "//tensorflow/python:dtypes", + "//tensorflow/python:framework_ops", "//tensorflow/python:framework_test_lib", + "//tensorflow/python:sparse_tensor", "//tensorflow/python/saved_model:signature_constants", "//tensorflow/python/saved_model:signature_def_utils", ], ) -py_strict_library( +py_library( name = "mode_keys", srcs = ["mode_keys.py"], srcs_version = "PY2AND3", - deps = ["//tensorflow/python:util"], ) -py_strict_test( +py_test( name = "mode_keys_test", srcs = ["mode_keys_test.py"], python_version = "PY3", diff --git a/tensorflow/tensorflow.bzl b/tensorflow/tensorflow.bzl index 1882251f6f0..a80499ae813 100644 --- a/tensorflow/tensorflow.bzl +++ b/tensorflow/tensorflow.bzl @@ -1932,10 +1932,6 @@ register_extension_info( def py_strict_library(name, **kwargs): native.py_library(name = name, **kwargs) -# Placeholder to use until bazel supports py_strict_test. -def py_strict_test(name, **kwargs): - native.py_test(name = name, **kwargs) - def tf_custom_op_py_library( name, srcs = [], From 6fd3e5e217a5ef4656c40bf0171a025d7b9d9183 Mon Sep 17 00:00:00 2001 From: Priya Gupta Date: Tue, 7 Jul 2020 22:56:56 -0700 Subject: [PATCH 25/88] Remove ReplicaContext.__enter__ and __exit__ methods from the docs since users are not expected to use them directly. 
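For context, a ReplicaContext is normally entered implicitly by tf.distribute.Strategy.run rather than through __enter__/__exit__; user code only queries it with tf.distribute.get_replica_context(). A minimal sketch of the intended usage (illustrative only, not taken from this change; the strategy choice and the replica_fn name are placeholders):

import tensorflow as tf

strategy = tf.distribute.MirroredStrategy()

def replica_fn():
  # Strategy.run has already entered the ReplicaContext for this replica,
  # so user code never calls __enter__/__exit__ itself.
  ctx = tf.distribute.get_replica_context()
  return ctx.replica_id_in_sync_group

per_replica_ids = strategy.run(replica_fn)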
PiperOrigin-RevId: 320129981 Change-Id: Ifd3ec638549de95c01147560260b8ecf7f17e29d --- tensorflow/python/distribute/distribute_lib.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 4a57628d4c7..78a199fe782 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -2728,6 +2728,7 @@ class ReplicaContext(object): self._replica_id_in_sync_group = replica_id_in_sync_group self._summary_recording_distribution_strategy = None + @doc_controls.do_not_generate_docs def __enter__(self): _push_per_thread_mode(self._thread_context) @@ -2740,6 +2741,7 @@ class ReplicaContext(object): summary_state.is_recording_distribution_strategy) summary_state.is_recording_distribution_strategy = replica_id_is_zero + @doc_controls.do_not_generate_docs def __exit__(self, exception_type, exception_value, traceback): summary_state = summary_ops_v2._summary_state # pylint: disable=protected-access summary_state.is_recording_distribution_strategy = ( From 89b2b978569784db32d7fd55ac1ba5a21dedff7a Mon Sep 17 00:00:00 2001 From: Dero Gharibian Date: Tue, 7 Jul 2020 23:33:28 -0700 Subject: [PATCH 26/88] Fully init TF_TString in order to mitigate potential msan false positives. Prior to this change, TF_TString_Init intentionally only initialized the first 2 bytes of a 24 byte struct. In some situations, where a temporary TF_TString struct is initialized via TF_TString_Init and subsequently copied, the copying of the uninitialized data may result in an msan use-of-uninitialized-value error. PiperOrigin-RevId: 320133690 Change-Id: I995ba629c4cb0b2a85a25c277caa74cf2dbc3bc3 --- tensorflow/core/platform/ctstring_internal.h | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/tensorflow/core/platform/ctstring_internal.h b/tensorflow/core/platform/ctstring_internal.h index f75fd04f955..9524267176c 100644 --- a/tensorflow/core/platform/ctstring_internal.h +++ b/tensorflow/core/platform/ctstring_internal.h @@ -169,8 +169,7 @@ static inline size_t TF_TString_ToInternalSizeT(size_t size, } static inline void TF_TString_Init(TF_TString *str) { - str->u.smll.size = 0; - str->u.smll.str[0] = '\0'; + memset(str->u.raw.raw, 0, sizeof(TF_TString_Raw)); } static inline void TF_TString_Dealloc(TF_TString *str) { From c688b1ab057b994d112aa3b51218dba82a95fba8 Mon Sep 17 00:00:00 2001 From: Brian Zhao Date: Tue, 7 Jul 2020 23:55:28 -0700 Subject: [PATCH 27/88] Rolling forward with #include of missing logging.h header. 
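For reference, the LOG and CHECK macros used in these sources come from tensorflow/core/platform/logging.h, so any translation unit that uses them needs to include that header directly instead of relying on a transitive include. A minimal sketch of the dependency (illustrative only; the helper below is hypothetical and not part of the diff):

#include <string>

#include "tensorflow/core/platform/logging.h"  // provides LOG(...) and CHECK(...)

// Hypothetical helper: without the logging.h include above, LOG(ERROR)
// fails to compile unless another header pulls it in transitively.
void ReportFailure(const std::string& message) {
  LOG(ERROR) << message;
}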
PiperOrigin-RevId: 320135405 Change-Id: I282170e72bfc74354d56f11c4e22d67886fc7296 --- .../c/experimental/saved_model/core/BUILD | 50 +++- .../saved_model/core/concrete_function.cc | 32 --- .../saved_model/core/concrete_function.h | 17 +- .../saved_model/core/revived_types/BUILD | 21 ++ .../revived_types/tf_concrete_function.cc | 87 ++++++ .../core/revived_types/tf_concrete_function.h | 87 ++++++ .../saved_model/core/saved_model_utils.cc | 137 +++++++++ .../saved_model/core/saved_model_utils.h | 9 + .../core/tf_concrete_function_loading_test.cc | 271 ++++++++++++++++++ .../core/tf_concrete_function_test_protos.cc | 213 ++++++++++++++ .../core/tf_concrete_function_test_protos.h | 50 ++++ .../c/experimental/saved_model/internal/BUILD | 8 +- .../saved_model/internal/concrete_function.cc | 10 +- .../saved_model/public/concrete_function.h | 2 +- 14 files changed, 942 insertions(+), 52 deletions(-) delete mode 100644 tensorflow/c/experimental/saved_model/core/concrete_function.cc create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc create mode 100644 tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h create mode 100644 tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc create mode 100644 tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc create mode 100644 tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h diff --git a/tensorflow/c/experimental/saved_model/core/BUILD b/tensorflow/c/experimental/saved_model/core/BUILD index 5452907f3e8..34d91679dad 100644 --- a/tensorflow/c/experimental/saved_model/core/BUILD +++ b/tensorflow/c/experimental/saved_model/core/BUILD @@ -19,9 +19,6 @@ package( cc_library( name = "concrete_function", - srcs = [ - "concrete_function.cc", - ], hdrs = [ "concrete_function.h", ], @@ -29,7 +26,6 @@ cc_library( ":function_metadata", "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:immediate_execution_tensor_handle", - "//tensorflow/core:protos_all_cc", ], ) @@ -60,10 +56,13 @@ cc_library( "saved_model_utils.h", ], deps = [ + ":function_metadata", "//tensorflow/c:tf_tensor_internal", "//tensorflow/c/eager:immediate_execution_context", "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", "//tensorflow/c/experimental/saved_model/core/revived_types:variable", + "//tensorflow/core:lib", "//tensorflow/core:protos_all_cc", ], ) @@ -91,6 +90,18 @@ cc_library( ], ) +cc_library( + name = "tf_concrete_function_test_protos", + testonly = True, + srcs = ["tf_concrete_function_test_protos.cc"], + hdrs = ["tf_concrete_function_test_protos.h"], + deps = [ + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "@com_google_absl//absl/strings", + ], +) + cc_library( name = "tf_saved_model_impl", srcs = [ @@ -114,12 +125,16 @@ cc_library( "saved_model_api.h", ], visibility = ["//tensorflow/python:__pkg__"], + deps = [ + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/core:lib", + ], ) filegroup( name = "mobile_srcs_only_runtime", srcs = [ - "concrete_function.cc", "concrete_function.h", "function_metadata.h", "saved_model_api.h", @@ -172,3 +187,28 @@ tf_cc_test( "//tensorflow/core/common_runtime/eager:core", ], ) + +tf_cc_test( + name = "tf_concrete_function_loading_test", + srcs = [ + 
"tf_concrete_function_loading_test.cc", + ], + deps = [ + ":saved_model_utils", + ":test_utils", + ":tf_concrete_function_test_protos", + "//tensorflow/c:tensor_interface", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core/revived_types:constant", + "//tensorflow/c/experimental/saved_model/core/revived_types:tensorhandle_convertible", + "//tensorflow/c/experimental/saved_model/core/revived_types:tf_concrete_function", + "//tensorflow/core:framework", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core:test", + "//tensorflow/core:test_main", + "//tensorflow/core/common_runtime:core_cpu_lib", + "//tensorflow/core/common_runtime/eager:context", + "//tensorflow/core/common_runtime/eager:core", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.cc b/tensorflow/c/experimental/saved_model/core/concrete_function.cc deleted file mode 100644 index 41bae4352fc..00000000000 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.cc +++ /dev/null @@ -1,32 +0,0 @@ -/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. - -Licensed under the Apache License, Version 2.0 (the "License"); -you may not use this file except in compliance with the License. -You may obtain a copy of the License at - - http://www.apache.org/licenses/LICENSE-2.0 - -Unless required by applicable law or agreed to in writing, software -distributed under the License is distributed on an "AS IS" BASIS, -WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -See the License for the specific language governing permissions and -limitations under the License. -==============================================================================*/ - -#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" - -#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" -#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" - -namespace tensorflow { - -const std::vector& -ConcreteFunction::GetCaptures() const { - return captures_; -} - -const FunctionMetadata& ConcreteFunction::GetFunctionMetadata() const { - return metadata_; -} - -} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/concrete_function.h b/tensorflow/c/experimental/saved_model/core/concrete_function.h index 22535641ef5..2cc627bcf27 100644 --- a/tensorflow/c/experimental/saved_model/core/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/core/concrete_function.h @@ -16,12 +16,12 @@ limitations under the License. #ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ #define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_CONCRETE_FUNCTION_H_ +#include #include #include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/immediate_execution_tensor_handle.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" -#include "tensorflow/core/framework/function.pb.h" namespace tensorflow { @@ -35,19 +35,14 @@ namespace tensorflow { // and have only a single implementation. class ConcreteFunction { public: - virtual ~ConcreteFunction() = 0; + virtual ~ConcreteFunction() = default; // This method returns the "Call" Op used to execute the function. 
- virtual ImmediateExecutionOperation* GetCallOp() = 0; + virtual Status GetCallOp(ImmediateOpPtr* out) = 0; - const std::vector& GetCaptures() - const; - const FunctionMetadata& GetFunctionMetadata() const; - - private: - FunctionMetadata metadata_; - std::vector captures_; - FunctionDef* function_; + virtual const std::vector& GetCaptures() + const = 0; + virtual const FunctionMetadata& GetFunctionMetadata() const = 0; }; } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD index 84fad2ea8f6..8bb15674db0 100644 --- a/tensorflow/c/experimental/saved_model/core/revived_types/BUILD +++ b/tensorflow/c/experimental/saved_model/core/revived_types/BUILD @@ -58,3 +58,24 @@ cc_library( "//tensorflow/c/eager:immediate_execution_tensor_handle", ], ) + +cc_library( + name = "tf_concrete_function", + srcs = [ + "tf_concrete_function.cc", + ], + hdrs = [ + "tf_concrete_function.h", + ], + deps = [ + ":tensorhandle_convertible", + "//tensorflow/c/eager:immediate_execution_context", + "//tensorflow/c/eager:immediate_execution_operation", + "//tensorflow/c/eager:immediate_execution_tensor_handle", + "//tensorflow/c/experimental/saved_model/core:concrete_function", + "//tensorflow/c/experimental/saved_model/core:function_metadata", + "//tensorflow/core:lib", + "//tensorflow/core:protos_all_cc", + "//tensorflow/core/common_runtime/eager:context", + ], +) diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc new file mode 100644 index 00000000000..aa6f0e7205e --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.cc @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" + +#include +#include + +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/core/common_runtime/eager/context.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/platform/errors.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { + +TFConcreteFunction::TFConcreteFunction( + const std::string& name, + std::vector captures, + FunctionMetadata metadata, ImmediateExecutionContext* ctx) + : name_(name), + captures_(std::move(captures)), + metadata_(std::move(metadata)), + ctx_(ctx) {} + +TFConcreteFunction::~TFConcreteFunction() { + Status status = ctx_->RemoveFunction(name_); + if (!status.ok()) { + LOG(ERROR) << "Failed to remove functiondef " << name_ << ". " + << status.error_message(); + } +} + +Status TFConcreteFunction::Create( + const FunctionDef* function_def, + std::vector captures, + FunctionMetadata metadata, ImmediateExecutionContext* ctx, + std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ctx->AddFunctionDef(*function_def)); + out->reset(new TFConcreteFunction(function_def->signature().name(), + std::move(captures), std::move(metadata), + ctx)); + return Status(); +} + +const std::vector& +TFConcreteFunction::GetCaptures() const { + return captures_; +} + +const FunctionMetadata& TFConcreteFunction::GetFunctionMetadata() const { + return metadata_; +} + +Status TFConcreteFunction::GetCallOp(ImmediateOpPtr* out) { + out->reset(ctx_->CreateOperation()); + // In eager mode, TF2 python executes functions by constructing an op with + // the name of the functiondef: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L545 + // In graph mode, we create a PartitionedCallOp instead: + // https://github.com/tensorflow/tensorflow/blob/66668ec0ca432e2f38a575b814f45b6d299d01ed/tensorflow/python/eager/function.py#L573 + + // TODO(bmzhao): After discussing with Allen, we should execute this via a + // PartitionedCallOp for compatibility with "tooling that assumes functions in + // graphs are PartitionedCallOps". + TF_RETURN_IF_ERROR((*out)->Reset(name_.c_str(), nullptr)); + return Status(); +} + +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h new file mode 100644 index 00000000000..71c8322414d --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h @@ -0,0 +1,87 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ + +#include +#include +#include +#include +#include + +#include "tensorflow/c/eager/immediate_execution_context.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { + +// TF Eager Runtime-based implementation of a "ConcreteFunction" loaded from a +// saved model. +class TFConcreteFunction : public ConcreteFunction { + public: + // Factory function for creating a TFConcreteFunction. + // + // Params: + // function_def - The function_def associated with the created + // TFConcreteFunction. TFConcreteFunction will register this + // function_def with `ctx` on creation, and de-register it on + // destruction. function_def must be non-null, but + // otherwise has no lifetime requirements. + // captures - The captured TensorHandles associated with this + // TFConcreteFunction. + // metadata - The FunctionMetadata associated with this TFConcreteFunction. + // ctx - A handle to the Tensorflow runtime. This MUST be non-null and + // outlive TFConcreteFunction. + // out - The output TFConcreteFunction. + static Status Create(const FunctionDef* function_def, + std::vector captures, + FunctionMetadata metadata, + ImmediateExecutionContext* ctx, + std::unique_ptr* out); + + // This method returns the "Call" Op used to execute the function. + Status GetCallOp(ImmediateOpPtr* out) override; + + const std::vector& GetCaptures() + const override; + + const FunctionMetadata& GetFunctionMetadata() const override; + + ~TFConcreteFunction() override; + + private: + TFConcreteFunction(const std::string& name, + std::vector captures, + FunctionMetadata metadata, ImmediateExecutionContext* ctx); + + TFConcreteFunction(const TFConcreteFunction&) = delete; + TFConcreteFunction& operator=(const TFConcreteFunction&) = delete; + + // Name of the FunctionDef corresponding to this TFConcreteFunction + std::string name_; + std::vector captures_; + FunctionMetadata metadata_; + ImmediateExecutionContext* ctx_; +}; + +} // namespace tensorflow + +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TF_CONCRETE_FUNCTION_H_ diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc index 196420eb537..4b1d7672ed6 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.cc @@ -17,14 +17,125 @@ limitations under the License. 
#include +#include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/c/tf_tensor_internal.h" #include "tensorflow/core/framework/tensor.pb.h" +#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" +#include "tensorflow/core/protobuf/struct.pb.h" namespace tensorflow { namespace internal { +namespace { + +// This returns the size of `tf.nest.flatten(value)`, on values that are +// used in tf.function's input_signatures. +int FlattenedSize(const tensorflow::StructuredValue& value, Status* status) { + // This follows the logic from + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L2775 + switch (value.kind_case()) { + case StructuredValue::kDictValue: { + const DictValue& dict = value.dict_value(); + int size = 0; + for (const auto& field : dict.fields()) { + size += FlattenedSize(field.second, status); + } + return size; + } + case StructuredValue::kTupleValue: { + const TupleValue& tuple = value.tuple_value(); + int size = 0; + for (const StructuredValue& value : tuple.values()) { + size += FlattenedSize(value, status); + } + return size; + } + case StructuredValue::kListValue: { + const ListValue& list = value.list_value(); + int size = 0; + for (const StructuredValue& value : list.values()) { + size += FlattenedSize(value, status); + } + return size; + } + case StructuredValue::kTensorSpecValue: { + return 1; + } + case StructuredValue::kNoneValue: { + // Base case: do nothing. + // This arises, for example, as the top-level object of an output + // signature when there are no return values. + return 0; + } + default: { + status->Update(errors::Internal("Unhandled structured value kind ", + value.kind_case())); + return 0; + } + } +} + +// Perform some basic sanity checks on SavedConcreteFunction's input and +// output signatures with respect to the corresponding FunctionDef's input +// and output args. +Status ValidateSavedFunctionCompatibleWithFunctionDef( + const SavedConcreteFunction& saved_concrete_function, + const FunctionDef* function_def) { + // tf.functions go through many transformations before becoming FunctionDefs + // 1. flatten user-provided inputs: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2671-L2675 + // 2. convert user-provided inputs to tensors: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L2687-L2688 + // 3. filter any non-tensor, non-variable inputs: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1840-L1841 + // 4. concatenate any captured inputs: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1912 + + // Since our API is limited to tf.functions annotated with input signatures, + // conditions 2 and 3 are trivially satisfied. 
+ // We need to ensure that: + // flatten(input_signature).size() + captures.size() = fdef.signature().size() + // A concrete function's serialized "canonicalized_input_signature" comes + // from encoding its "structured_input_signature" field: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/saved_model/function_serialization.py#L70-L71 + // The "structured_input_signature" is guaranteed to be a tuple of the python + // args, kwargs that correspond to the tf.function: + // https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/eager/function.py#L1974-L1979 + + const std::string& name = function_def->signature().name(); + const StructuredValue& input_signature = + saved_concrete_function.canonicalized_input_signature(); + Status status; + int input_signature_size = FlattenedSize(input_signature, &status); + TF_RETURN_IF_ERROR(status); + if (input_signature_size + saved_concrete_function.bound_inputs_size() != + function_def->signature().input_arg_size()) { + return errors::FailedPrecondition( + "FunctionDef ", name, " has ", + function_def->signature().input_arg_size(), + " inputs, but the SavedConcreteFunction has ", input_signature_size, + " flattened user inputs and ", + saved_concrete_function.bound_inputs_size(), " captured inputs."); + } + + const StructuredValue& output_signature = + saved_concrete_function.output_signature(); + int output_signature_size = FlattenedSize(output_signature, &status); + TF_RETURN_IF_ERROR(status); + if (output_signature_size != function_def->signature().output_arg_size()) { + return errors::FailedPrecondition( + "FunctionDef ", name, " has ", + function_def->signature().output_arg_size(), + " outputs, but the SavedConcreteFunction has ", output_signature_size, + " flattened outputs."); + } + + return status; +} + +} // namespace Status TensorProtoToConstant(ImmediateExecutionContext* ctx, const TensorProto& proto, @@ -54,5 +165,31 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, return Status(); } +Status LoadTFConcreteFunction( + const SavedConcreteFunction& saved_concrete_function, + const FunctionDef* function_def, + const std::unordered_map>& + captured_objects, + ImmediateExecutionContext* ctx, std::unique_ptr* out) { + TF_RETURN_IF_ERROR(ValidateSavedFunctionCompatibleWithFunctionDef( + saved_concrete_function, function_def)); + + // Copy over captures + std::vector captures; + captures.reserve(saved_concrete_function.bound_inputs_size()); + for (int bound_input : saved_concrete_function.bound_inputs()) { + auto iter = captured_objects.find(bound_input); + if (iter == captured_objects.end()) { + return errors::FailedPrecondition("Failed to find bound_input ", + bound_input, + " for SavedConcreteFunction"); + } + captures.push_back(iter->second->handle()); + } + + return TFConcreteFunction::Create(function_def, std::move(captures), {}, ctx, + out); +} + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h index ab1531709e4..89a959a89d4 100644 --- a/tensorflow/c/experimental/saved_model/core/saved_model_utils.h +++ b/tensorflow/c/experimental/saved_model/core/saved_model_utils.h @@ -21,6 +21,7 @@ limitations under the License. 
#include "tensorflow/c/eager/immediate_execution_context.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/revived_types/variable.h" #include "tensorflow/core/framework/tensor.pb.h" #include "tensorflow/core/protobuf/saved_object_graph.pb.h" @@ -43,6 +44,14 @@ Status LoadSavedVariable(ImmediateExecutionContext* ctx, const SavedVariable& variable, std::unique_ptr* output); +// Creates a TFConcreteFunction from a SavedConcreteFunction. +Status LoadTFConcreteFunction( + const SavedConcreteFunction& saved_concrete_function, + const FunctionDef* function_def, + const std::unordered_map>& + captured_objects, + ImmediateExecutionContext* ctx, std::unique_ptr* out); + } // namespace internal } // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc new file mode 100644 index 00000000000..05fbac13077 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_loading_test.cc @@ -0,0 +1,271 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include +#include + +#include "tensorflow/c/eager/immediate_execution_tensor_handle.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h" +#include "tensorflow/c/experimental/saved_model/core/revived_types/tf_concrete_function.h" +#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h" +#include "tensorflow/c/experimental/saved_model/core/test_utils.h" +#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" +#include "tensorflow/core/framework/function.pb.h" +#include "tensorflow/core/framework/op_def.pb.h" +#include "tensorflow/core/framework/types.pb.h" +#include "tensorflow/core/lib/core/status_test_util.h" +#include "tensorflow/core/platform/status.h" +#include "tensorflow/core/platform/test.h" +#include "tensorflow/core/protobuf/error_codes.pb.h" +#include "tensorflow/core/protobuf/saved_object_graph.pb.h" + +namespace tensorflow { +namespace { + +class SavedConcreteFunctionLoadingTest : public ::testing::Test { + public: + SavedConcreteFunctionLoadingTest() + : device_mgr_(testing::CreateTestingDeviceMgr()), + ctx_(testing::CreateTestingEagerContext(device_mgr_.get())) {} + + EagerContext* context() { return ctx_.get(); } + + private: + std::unique_ptr device_mgr_; + EagerContextPtr ctx_; +}; + +class DummyCapture : public TensorHandleConvertible { + public: + DummyCapture(ImmediateExecutionContext* ctx, int8 value) + : TensorHandleConvertible( + testing::CreateTensorHandle(ctx, DT_FLOAT, {2, 4}, value)) {} +}; + +FunctionDef FuncDefWithNumInputsOutputs(int num_inputs, int num_outputs) { + FunctionDef func; + OpDef* signature = func.mutable_signature(); + for (int i = 0; i < num_inputs; ++i) { + signature->add_input_arg(); + } + for (int i = 0; i < num_outputs; ++i) { + signature->add_output_arg(); + } + return func; +} + +// A SavedConcreteFunction whose canonicalized input signature +// has less inputs than its corresponding FunctionDef should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, TooFewInputsInSavedConcreteFunction) { + // `saved` has 1 input + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + + // `func` has 2 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(2, 0); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose canonicalized input signature length + +// captures is less than its corresponding FunctionDef should cause an error. 
+TEST_F(SavedConcreteFunctionLoadingTest, + TooFewInputsWithCapturesInSavedConcreteFunction) { + // `saved` has 1 input, and 1 capture, for a total of 2 inputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + saved.add_bound_inputs(5); + + // `func` has 3 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(3, 0); + + std::unordered_map> captures; + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose canonicalized input signature +// has more inputs than its corresponding FunctionDef should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, TooManyInputsInSavedConcreteFunction) { + // `saved` has 3 inputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ThreeArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + + // `func` has 2 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(2, 0); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose canonicalized input signature +// has the same number of inputs than its corresponding FunctionDef, but has +// additional captures should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, + TooManyInputsWithCaptureInSavedConcreteFunction) { + // `saved` has 3 inputs, and 1 capture, for a total of 4 inputs. + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ThreeArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + saved.add_bound_inputs(5); + + // `func` has 3 inputs. + FunctionDef func = FuncDefWithNumInputsOutputs(3, 0); + + std::unordered_map> captures; + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose capture refers to an index not in the capture +// map should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, ImproperCaptureIndex) { + // `saved` has 3 inputs, 1 capture, for a total of 4 inputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ThreeArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + // Capture is at index "10" + saved.add_bound_inputs(10); + + // `func` has 4 inputs + FunctionDef func = FuncDefWithNumInputsOutputs(4, 0); + + // `captures` only has a capture for index 5 + std::unordered_map> captures; + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose outputs are fewer than its corresponding +// functiondef should cause an error. 
+TEST_F(SavedConcreteFunctionLoadingTest, TooFewOutputsInSavedConcreteFunction) { + // `saved` has 0 inputs, 1 output + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ZeroArgInputSignature(); + *saved.mutable_output_signature() = testing::SingleReturnOutputSignature(); + + // `func` has 0 inputs, 2 outputs + FunctionDef func = FuncDefWithNumInputsOutputs(0, 2); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose outputs exceed its corresponding functiondef +// should cause an error. +TEST_F(SavedConcreteFunctionLoadingTest, + TooManyOutputsInSavedConcreteFunction) { + // `saved` has 1 input, 3 outputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature(); + + // `func` has 1 input, 2 outputs + FunctionDef func = FuncDefWithNumInputsOutputs(1, 2); + + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + EXPECT_EQ(status.code(), error::FAILED_PRECONDITION) + << status.error_message(); +} + +// A SavedConcreteFunction whose (inputs + captures) = functiondef inputs, +// and whose outputs = functiondef outputs should successfully load. +TEST_F(SavedConcreteFunctionLoadingTest, SuccessfulLoad) { + // `saved` has 1 input, 2 captures, 3 outputs + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::SingleArgInputSignature(); + *saved.mutable_output_signature() = testing::ThreeReturnOutputSignature(); + saved.add_bound_inputs(2); + saved.add_bound_inputs(5); + + // `func` has 3 inputs, 3 outputs + FunctionDef func = FuncDefWithNumInputsOutputs(3, 3); + + std::unordered_map> captures; + captures[2] = std::make_unique(context(), 1); + captures[5] = std::make_unique(context(), 10); + + std::unique_ptr result; + Status status = internal::LoadTFConcreteFunction(saved, &func, captures, + context(), &result); + TF_EXPECT_OK(status) << status.error_message(); +} + +// A TFConcreteFunction should register functiondefs on creation, and +// remove them upon deletion. +TEST_F(SavedConcreteFunctionLoadingTest, RegistersAndRemovesFunctionDefs) { + std::string func_name = "FooBarBazWombatFunction"; + + SavedConcreteFunction saved; + *saved.mutable_canonicalized_input_signature() = + testing::ZeroArgInputSignature(); + *saved.mutable_output_signature() = testing::ZeroReturnOutputSignature(); + FunctionDef func = FuncDefWithNumInputsOutputs(0, 0); + *func.mutable_signature()->mutable_name() = func_name; + + { + std::unique_ptr result; + Status status = + internal::LoadTFConcreteFunction(saved, &func, {}, context(), &result); + TF_EXPECT_OK(status) << status.error_message(); + // The function should be registered with context. + EXPECT_TRUE(context()->FindFunctionByName(func_name)); + } + + // After `result's` destructor runs, the function should no longer be + // registered with context. 
+ EXPECT_FALSE(context()->FindFunctionByName(func_name)); +} + +} // namespace +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc new file mode 100644 index 00000000000..6250af6dba1 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.cc @@ -0,0 +1,213 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h" + +#include + +#include "absl/strings/string_view.h" +#include "tensorflow/core/platform/logging.h" +#include "tensorflow/core/platform/protobuf.h" +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace testing { +namespace { + +constexpr absl::string_view kZeroArgInputSignatureTextProto = R"( +tuple_value: { + values: { + tuple_value: { + } + } + values: { + dict_value: { + } + } +} +)"; + +constexpr absl::string_view kSingleArgInputSignatureTextProto = R"( +tuple_value: { + values: { + tuple_value: { + values: { + tensor_spec_value: { + name : "x" + shape: { + dim: { + size: 1 + } + dim: { + size: 10 + } + } + dtype: DT_FLOAT + } + } + } + } + values: { + dict_value: { + } + } +} +)"; + +constexpr absl::string_view kThreeArgInputSignatureTextProto = R"( +tuple_value: { + values: { + tuple_value: { + values: { + tensor_spec_value: { + name : "x" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + values: { + tensor_spec_value: { + name : "y" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + values: { + tensor_spec_value: { + name : "z" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + } + } + values: { + dict_value: { + } + } +} + +)"; + +constexpr absl::string_view kZeroReturnOutputSignatureTextProto = R"( +none_value: {} +)"; + +constexpr absl::string_view kSingleReturnOutputSignatureTextProto = R"( +tensor_spec_value: { + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT +} +)"; + +constexpr absl::string_view kThreeReturnOutputSignatureTextProto = R"( +tuple_value: { + values: { + dict_value: { + fields: { + key : "a" + value: { + tensor_spec_value: { + name : "0/a" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + } + fields: { + key : "b" + value: { + tensor_spec_value: { + name : "0/b" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } + } + } + } + values: { + tensor_spec_value: { + name : "1" + shape: { + dim: { + size: 1 + } + } + dtype: DT_FLOAT + } + } +} +)"; + +StructuredValue ParseStructuredValue(absl::string_view text_proto) { + StructuredValue value; + CHECK(tensorflow::protobuf::TextFormat::ParseFromString(string(text_proto), + &value)); + return value; +} + +} // namespace + +StructuredValue ZeroArgInputSignature() { + return 
ParseStructuredValue(kZeroArgInputSignatureTextProto); +} + +StructuredValue SingleArgInputSignature() { + return ParseStructuredValue(kSingleArgInputSignatureTextProto); +} + +StructuredValue ThreeArgInputSignature() { + return ParseStructuredValue(kThreeArgInputSignatureTextProto); +} + +StructuredValue ZeroReturnOutputSignature() { + return ParseStructuredValue(kZeroReturnOutputSignatureTextProto); +} + +StructuredValue SingleReturnOutputSignature() { + return ParseStructuredValue(kSingleReturnOutputSignatureTextProto); +} + +StructuredValue ThreeReturnOutputSignature() { + return ParseStructuredValue(kThreeReturnOutputSignatureTextProto); +} + +} // namespace testing +} // namespace tensorflow diff --git a/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h new file mode 100644 index 00000000000..8aa7d5694e1 --- /dev/null +++ b/tensorflow/c/experimental/saved_model/core/tf_concrete_function_test_protos.h @@ -0,0 +1,50 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ +#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ + +#include "tensorflow/core/protobuf/struct.pb.h" + +namespace tensorflow { +namespace testing { + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 0 inputs +StructuredValue ZeroArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 1 input +StructuredValue SingleArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized InputSignature of a +// tf.function with 3 inputs +StructuredValue ThreeArgInputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with no return values +StructuredValue ZeroReturnOutputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with a single tensor output +StructuredValue SingleReturnOutputSignature(); + +// Returns a StructuredValue corresponding to the serialized OutputSignature of +// a tf.function with three tensor outputs +StructuredValue ThreeReturnOutputSignature(); + +} // namespace testing +} // namespace tensorflow +#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_TF_CONCRETE_FUNCTION_TEST_PROTOS_H_ diff --git a/tensorflow/c/experimental/saved_model/internal/BUILD b/tensorflow/c/experimental/saved_model/internal/BUILD index 888c284bb12..6be2a02eeb8 100644 --- a/tensorflow/c/experimental/saved_model/internal/BUILD +++ b/tensorflow/c/experimental/saved_model/internal/BUILD @@ -41,11 +41,13 @@ cc_library( ":tensorhandle_list", ":tensorhandle_list_type", "//tensorflow/c:c_api_macros", + "//tensorflow/c:tf_status_internal", 
"//tensorflow/c/eager:c_api", - "//tensorflow/c/eager:c_api_internal", + "//tensorflow/c/eager:immediate_execution_operation", "//tensorflow/c/eager:tfe_op_internal", "//tensorflow/c/experimental/saved_model/core:concrete_function", "//tensorflow/c/experimental/saved_model/core:function_metadata", + "//tensorflow/core:lib", ], ) @@ -205,9 +207,13 @@ tf_cc_test( ], deps = [ "//tensorflow/c:tf_status", + "//tensorflow/c:tf_tensor", "//tensorflow/c/eager:c_api", "//tensorflow/c/eager:c_api_experimental", + "//tensorflow/c/eager:c_api_test_util", + "//tensorflow/c/experimental/saved_model/public:concrete_function", "//tensorflow/c/experimental/saved_model/public:saved_model_api", + "//tensorflow/c/experimental/saved_model/public:tensorhandle_list", "//tensorflow/core:lib", "//tensorflow/core:test", "//tensorflow/core:test_main", diff --git a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc index dd54416ddf9..12d49212a88 100644 --- a/tensorflow/c/experimental/saved_model/internal/concrete_function.cc +++ b/tensorflow/c/experimental/saved_model/internal/concrete_function.cc @@ -15,12 +15,15 @@ limitations under the License. #include "tensorflow/c/experimental/saved_model/public/concrete_function.h" +#include "tensorflow/c/eager/immediate_execution_operation.h" #include "tensorflow/c/eager/tfe_op_internal.h" #include "tensorflow/c/experimental/saved_model/core/concrete_function.h" #include "tensorflow/c/experimental/saved_model/core/function_metadata.h" #include "tensorflow/c/experimental/saved_model/internal/concrete_function_type.h" #include "tensorflow/c/experimental/saved_model/internal/function_metadata_type.h" #include "tensorflow/c/experimental/saved_model/internal/tensorhandle_list_type.h" +#include "tensorflow/c/tf_status_internal.h" +#include "tensorflow/core/platform/status.h" extern "C" { @@ -34,8 +37,11 @@ const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( return tensorflow::wrap(&tensorflow::unwrap(func)->GetCaptures()); } -TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func) { - return tensorflow::wrap(tensorflow::unwrap(func)->GetCallOp()); +TFE_Op* TF_ConcreteFunctionGetCallOp(TF_ConcreteFunction* func, + TF_Status* status) { + tensorflow::ImmediateOpPtr call_op(nullptr); + status->status = tensorflow::unwrap(func)->GetCallOp(&call_op); + return tensorflow::wrap(call_op.release()); } } // end extern "C" diff --git a/tensorflow/c/experimental/saved_model/public/concrete_function.h b/tensorflow/c/experimental/saved_model/public/concrete_function.h index 2a87214270c..944ddecea16 100644 --- a/tensorflow/c/experimental/saved_model/public/concrete_function.h +++ b/tensorflow/c/experimental/saved_model/public/concrete_function.h @@ -41,7 +41,7 @@ TF_CAPI_EXPORT extern const TF_TensorHandleList* TF_ConcreteFunctionGetCaptures( // Returns a TFE_Op suitable for executing this function. 
TF_CAPI_EXPORT extern TFE_Op* TF_ConcreteFunctionGetCallOp( - TF_ConcreteFunction* func); + TF_ConcreteFunction* func, TF_Status* status); #ifdef __cplusplus } // end extern "C" From 91c5b330e89d0baeb77bc7d2a599d0502b457b16 Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Wed, 8 Jul 2020 00:23:49 -0700 Subject: [PATCH 28/88] Integrate LLVM at https://github.com/llvm/llvm-project/commit/f54d0e36be6a PiperOrigin-RevId: 320138569 Change-Id: Ia0f4549d97b05a993eaf6aab7ce700f205106d5a --- tensorflow/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index d1c2ec41a94..26a8e8271a8 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -710,8 +710,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. - LLVM_COMMIT = "8691544a276744474ff04b71d7e220069435c7fe" - LLVM_SHA256 = "125833431ed7989b50ad05322d1ec5d5470d6c1f5656d7d115fd44289904f634" + LLVM_COMMIT = "f54d0e36be6a4d5dab67244e85b8664282dcf5d1" + LLVM_SHA256 = "df115acb6b5b1a5e2f49819a0b6ae8cf47ccecb61a10d75f85a67198ff06420e" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), From a1cae5e5acf90572f56e8f00fa8dad5008747fe0 Mon Sep 17 00:00:00 2001 From: Anirudh Sriram Date: Wed, 8 Jul 2020 00:53:37 -0700 Subject: [PATCH 29/88] Updates to Profiler Guide PiperOrigin-RevId: 320141724 Change-Id: Ic3033d235884d32cdf3c92df9068fb4415d45e53 --- tensorflow/python/profiler/profiler_v2.py | 40 ++++++++++++----------- 1 file changed, 21 insertions(+), 19 deletions(-) diff --git a/tensorflow/python/profiler/profiler_v2.py b/tensorflow/python/profiler/profiler_v2.py index cacf38069ff..bcba9d52d23 100644 --- a/tensorflow/python/profiler/profiler_v2.py +++ b/tensorflow/python/profiler/profiler_v2.py @@ -51,19 +51,18 @@ class ProfilerOptions( collections.namedtuple( 'ProfilerOptions', ['host_tracer_level', 'python_tracer_level', 'device_tracer_level'])): - """Options to control profiler behaviors. + """Options for finer control over the profiler. - A `tf.profiler.ProfilerOptions` hold the knobs to control tf.profiler's + Use `tf.profiler.ProfilerOptions` to control `tf.profiler` behavior. Fields: - host_tracer_level: for adjust TraceMe levels. i.e. 1 => critical, - 2 => info, 3 => verbose. [default to 2] - python_tracer_level: for enable python function call tracing, 1 => enable. - 0 => disable [default to 0] - device_tracer_level: for adjust device (TPU/GPU) tracer level, 0 => disable - 1 => enabled. We may introduce fine-tuned level in the - future. [default to 1] + host_tracer_level: Adjust CPU tracing level. Values are: 1 - critical info + only, 2 - info, 3 - verbose. [default value is 2] + python_tracer_level: Toggle tracing of Python function calls. Values are: 1 + - enabled, 0 - disabled [default value is 0] + device_tracer_level: Adjust device (TPU/GPU) tracing level. Values are: 1 - + enabled, 0 - disabled [default value is 1] """ def __new__(cls, @@ -77,26 +76,29 @@ class ProfilerOptions( @tf_export('profiler.experimental.start', v1=[]) def start(logdir, options=None): - """Starts profiling. + """Start profiling TensorFlow performance. Args: - logdir: A log directory read by TensorBoard to export the profile results. - options: namedtuple of ProfilerOptions for miscellaneous profiler options. 
+ logdir: Profiling results log directory. + options: `ProfilerOptions` namedtuple to specify miscellaneous profiler + options. See example usage below. Raises: - AlreadyExistsError: If another profiling session is running. + AlreadyExistsError: If a profiling session is already running. Example usage: ```python - tf.profiler.experimental.start( - 'logdir_path', tf.profiler.ProfilerOptions(host_tracer_level=2)) - # do your training here. + options = tf.profiler.experimental.ProfilerOptions(host_tracer_level = 3, + python_tracer_level = 1, + device_tracer_level = 1) + tf.profiler.experimental.start('logdir_path', options = options) + # Training code here tf.profiler.experimental.stop() ``` - Launch TensorBoard and point it to the same logdir you provided to this API. - $ tensorboard --logdir=logdir_path - Open your browser and go to localhost:6006/#profile to view profiling results. + To view the profiling results, launch TensorBoard and point it to `logdir`. + Open your browser and go to `localhost:6006/#profile` to view profiling + results. """ global _profiler From efd35d70d22d370bf4a997cbf53a8030031a48da Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 8 Jul 2020 01:43:30 -0700 Subject: [PATCH 30/88] [MLIR] Convert FuncOp signature with unranked types in HLO->LHLO conversion. PiperOrigin-RevId: 320146856 Change-Id: Ic534e97b2eecbd4573b91ff48ef90d38bbacd9a4 --- .../lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc | 7 +++---- .../mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir | 8 ++++++++ 2 files changed, 11 insertions(+), 4 deletions(-) create mode 100644 tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 4d9cb62a945..e35be3179d2 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -391,16 +391,15 @@ struct HloLegalizeToLhlo target.addIllegalDialect(); BufferAssignmentTypeConverter converter; + auto isMemRefType = [](Type type) { return type.isa(); }; target.addDynamicallyLegalOp([&](FuncOp op) { auto inputs = op.getType().getInputs(); - return llvm::all_of(inputs, - [](Type input) { return input.isa(); }) && + return llvm::all_of(inputs, isMemRefType) && converter.isLegal(&op.getBody()); }); target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { return std::all_of(returnOp.operand_type_begin(), - returnOp.operand_type_end(), - [](Type type) { return type.isa(); }); + returnOp.operand_type_end(), isMemRefType); }); auto module = getOperation(); diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir new file mode 100644 index 00000000000..063716a539b --- /dev/null +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -0,0 +1,8 @@ +// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo=results-escape-function=true -buffer-placement %s -o - | FileCheck %s + +// CHECK-LABEL: func @func_op_unranked_arg_result +func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { + return %arg0 : tensor<*xf32> +} +// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32> +// CHECK-NEXT: return [[ARG]] : memref<*xf32> From d6060e0263968998e74bc625d68231bbb1277cab Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 8 Jul 2020 02:01:37 -0700 Subject: [PATCH 31/88] Update GraphDef version to 456. PiperOrigin-RevId: 320148492 Change-Id: Ie8da7f170556efa6a973601300079f6916785101 --- tensorflow/core/public/version.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/core/public/version.h b/tensorflow/core/public/version.h index 42cac007d44..e32e6a138c8 100644 --- a/tensorflow/core/public/version.h +++ b/tensorflow/core/public/version.h @@ -108,7 +108,7 @@ limitations under the License. #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 -#define TF_GRAPH_DEF_VERSION 455 // Updated: 2020/7/7 +#define TF_GRAPH_DEF_VERSION 456 // Updated: 2020/7/8 // Checkpoint compatibility versions (the versions field in SavedSliceMeta). // From f91f6e834755610bdef73b140b8ae9132e6298d1 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 02:01:37 -0700 Subject: [PATCH 32/88] compat: Update forward compatibility horizon to 2020-07-08 PiperOrigin-RevId: 320148494 Change-Id: Id05e21e7aa78eaf55e9300b411c26fc0cb0573a8 --- tensorflow/python/compat/compat.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/python/compat/compat.py b/tensorflow/python/compat/compat.py index 7b1472291cd..05a7082124f 100644 --- a/tensorflow/python/compat/compat.py +++ b/tensorflow/python/compat/compat.py @@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export # This value changes every day with an automatic CL. It can be modified in code # via `forward_compatibility_horizon()` or with the environment variable # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. -_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 7, 7) +_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 7, 8) _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DATE_NUMBER = None From b76ea0e912af6f7e4231ce63e4c13fd9f8405346 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 8 Jul 2020 02:11:30 -0700 Subject: [PATCH 33/88] [MLIR][LHLO] Convert mhlo.dynamic_reshape -> lhlo.reshape_memref_cast. 
PiperOrigin-RevId: 320149593 Change-Id: Ie85979a26225dbf017ad4bb888de1548e55a6d7f --- .../mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 8 ------ .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 26 +++++++++++++++++++ .../tests/hlo-legalize-to-lhlo-unranked.mlir | 26 +++++++++++++++++++ 3 files changed, 52 insertions(+), 8 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 5a8c3ccd4a4..46561bb8a03 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -470,14 +470,6 @@ def ReshapeMemRefCastOp: Op]; - let extraClassDeclaration = [{ MemRefType getType() { return getResult().getType().cast(); } }]; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index e35be3179d2..e7746f09eef 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -220,6 +220,31 @@ struct HloToLhloDynamicBroadcastInDimOpConverter } }; +struct HloToLhloDynamicReshapeConverter + : public BaseOpConversion { + public: + using BaseOpConversion::BaseOpConversion; + + LogicalResult matchAndRewrite( + mhlo::DynamicReshapeOp op, ArrayRef operands, + ConversionPatternRewriter& rewriter) const final { + Type result_type; + if (auto ranked_type = op.getType().dyn_cast()) { + result_type = + MemRefType::get(ranked_type.getShape(), ranked_type.getElementType()); + } else if (auto unranked_type = + op.getType().dyn_cast()) { + result_type = UnrankedMemRefType::get(unranked_type.getElementType(), 0); + } else { + return failure(); + } + mhlo::DynamicReshapeOp::Adaptor adaptor(operands); + rewriter.replaceOpWithNewOp( + op, result_type, adaptor.operand(), adaptor.output_shape()); + return success(); + } +}; + struct HloToLhloReduceOpConverter : public BaseOpConversion { public: using BaseOpConversion::BaseOpConversion; @@ -441,6 +466,7 @@ void populateHLOToLHLOConversionPattern( // clang-format off patterns->insert< HloToLhloDynamicBroadcastInDimOpConverter, + HloToLhloDynamicReshapeConverter, HloToLhloOpConverter, HloToLhloOpConverter, HloToLhloOpConverter, diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir index 063716a539b..cc60217be65 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo-unranked.mlir @@ -6,3 +6,29 @@ func @func_op_unranked_arg_result(%arg0: tensor<*xf32>) -> tensor<*xf32> { } // CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>) -> memref<*xf32> // CHECK-NEXT: return [[ARG]] : memref<*xf32> + +// ----- + +// CHECK-LABEL: func @dynamic_reshape_from_unranked +func @dynamic_reshape_from_unranked( + %operand: tensor<*xf32>, %shape: tensor<1xi32>) -> tensor { + %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) + : (tensor<*xf32>, tensor<1xi32>) -> tensor + return %reshaped : tensor +} +// CHECK-SAME: ([[ARG:%.*]]: memref<*xf32>, [[SHAPE:%.*]]: memref<1xi32>) +// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref + +// ----- + +// CHECK-LABEL: func @dynamic_reshape_to_unranked +func 
@dynamic_reshape_to_unranked( + %operand: tensor, %shape: tensor) -> tensor<*xf32> { + %reshaped = "mhlo.dynamic_reshape"(%operand, %shape) + : (tensor, tensor) -> tensor<*xf32> + return %reshaped : tensor<*xf32> +} +// CHECK-SAME: ([[ARG:%.*]]: memref, [[SHAPE:%.*]]: memref) +// CHECK-NEXT: reshape_memref_cast [[ARG]]([[SHAPE]]) +// CHECK-SAME: : (memref, memref) -> memref<*xf32> From 375484f85a5ac9798a60802f4c4c46cb749ce289 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 04:20:55 -0700 Subject: [PATCH 34/88] Add error message for outfeed shape mismatch. This error is triggered when the XLA program outfeeds data with a shape that is not compatible with the shape expected by the runtime. This error can be triggered by the end user. Currently, the error is not very helpful. This change adds the two mismatched shapes to the error message, in the spirit of the error message reported on CPU for the same error. (see tensorflow/compiler/xla/service/cpu/cpu_runtime.cc line 289) PiperOrigin-RevId: 320162796 Change-Id: I723e70f740a2b0734b7be575ea20d5bed19f55e7 --- tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc index 20a8d3b13a9..25ab1b54f07 100644 --- a/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc +++ b/tensorflow/compiler/xla/service/gpu/outfeed_thunk.cc @@ -45,7 +45,11 @@ Status OutfeedThunk::ExecuteOnStream(const ExecuteParams& params) { return Status::OK(); } CHECK(ShapeUtil::Compatible(hlo_instruction()->operand(0)->shape(), - outfeed_buffers->shape())); + outfeed_buffers->shape())) + << "XLA program outfeed request of shape " + << hlo_instruction()->operand(0)->shape().ToString() + << " did not match the runtime's outfeed buffer of shape " + << outfeed_buffers->shape().ToString(); TF_RETURN_IF_ERROR(outfeed_buffers->ForEachMutableElementWithStatus( [&](const ShapeIndex& index, std::unique_ptr* buffer) { From 4852acce4ba4ff880e3617e193f50dcd22cbecc2 Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 8 Jul 2020 05:59:45 -0700 Subject: [PATCH 35/88] [MLIR][LHLO] Legalize CallOp that call funcs with tensor args/results. 
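[Background on the outfeed change in PATCH 34 above: with both shapes streamed into the CHECK, a mismatch now fails with a message along the lines of "XLA program outfeed request of shape f32[2,3] did not match the runtime's outfeed buffer of shape f32[3,2]" (these shapes are invented purely for illustration), rather than a bare CHECK-condition failure, following the spirit of the CPU runtime's message for the same condition.]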
PiperOrigin-RevId: 320172723 Change-Id: Ib81a360aa8e85b778614e0a0e2aee5cf8947ad39 --- .../Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc | 12 +++++++++--- .../mlir/hlo/tests/hlo-legalize-to-lhlo.mlir | 5 +++-- 2 files changed, 12 insertions(+), 5 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index e7746f09eef..1162b0ecb6b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -422,9 +422,15 @@ struct HloLegalizeToLhlo return llvm::all_of(inputs, isMemRefType) && converter.isLegal(&op.getBody()); }); - target.addDynamicallyLegalOp([&](mlir::ReturnOp returnOp) { - return std::all_of(returnOp.operand_type_begin(), - returnOp.operand_type_end(), isMemRefType); + target.addDynamicallyLegalOp([&](CallOp op) { + return std::all_of(op.operand_type_begin(), op.operand_type_end(), + isMemRefType) && + std::all_of(op.result_type_begin(), op.result_type_end(), + isMemRefType); + }); + target.addDynamicallyLegalOp([&](mlir::ReturnOp op) { + return std::all_of(op.operand_type_begin(), op.operand_type_end(), + isMemRefType); }); auto module = getOperation(); diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index 0db595c4386..a5559357bdc 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -169,11 +169,12 @@ func @external_func() -> tensor<3xi64> func @dyn_broadcast(%operand: memref) { // BOTH-SAME: (%[[OPERAND:.*]]: memref) %tensor_operand = tensor_load %operand : memref - %shape = call @external_func() : () -> tensor<3xi64> + %c1 = constant 1 : i64 + %shape = tensor_from_elements(%c1, %c1, %c1) : tensor<3xi64> %tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) { broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> } : (tensor, tensor<3xi64>) -> tensor - // BOTH: %[[SHAPE:.*]] = call @external_func() + // BOTH: %[[SHAPE:.*]] = tensor_from_elements // BOTH: %[[C0:.*]] = constant 0 : index // BOTH: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64> // BOTH: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index From 6f914d11db1ee4ba3e3f226e766f1a0bfe37da23 Mon Sep 17 00:00:00 2001 From: Benjamin Kramer Date: Wed, 8 Jul 2020 06:29:16 -0700 Subject: [PATCH 36/88] [XLA:CPU] Teach dot_op_emitter how to tile&vectorize linalg matmuls This is on-par with the existing emitter, sometimes better and unlocks more potential. The strategy classes are duplicated right now, but I expect them to graduate to mlir core soon. I'm planning to remove the custom LLVM IR emitters if this turns out to be stable enough. 
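[Background on the strategy wired up in the dot_op_emitter changes below: the linalg matmul is tiled, the tiles are promoted into aligned, alloca-backed full-tile buffers, the tiled loops are vectorized, and the resulting vector contractions are lowered to outer products with unrolled vector transfers. The tile sizes are {4, number of vector registers, elements per vector register}; a purely illustrative worked example, with hardware numbers assumed rather than taken from the patch:

    // Illustrative only: assumed AVX2-class target with 16 x 256-bit registers, f32 data.
    //   elements_per_register = 256 / 32 = 8
    //   tile sizes            = {4, 16, 8}   ({4, register count, elements per register})
]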
PiperOrigin-RevId: 320175958 Change-Id: I2df3d09c40041f8a69dac1cb45e945af203ec6e1 --- tensorflow/compiler/xla/service/cpu/BUILD | 26 +- .../xla/service/cpu/dot_op_emitter.cc | 38 +++ .../compiler/xla/service/cpu/mlir_emitter.cc | 8 +- .../cpu/mlir_matmul_codegen_strategy.cc | 269 ++++++++++++++++++ .../cpu/mlir_matmul_codegen_strategy.h | 188 ++++++++++++ .../xla/service/cpu/target_machine_features.h | 12 + .../cpu/target_machine_features_fake.h | 4 + 7 files changed, 539 insertions(+), 6 deletions(-) create mode 100644 tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc create mode 100644 tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h diff --git a/tensorflow/compiler/xla/service/cpu/BUILD b/tensorflow/compiler/xla/service/cpu/BUILD index 102753b882f..b9e10bfb083 100644 --- a/tensorflow/compiler/xla/service/cpu/BUILD +++ b/tensorflow/compiler/xla/service/cpu/BUILD @@ -471,6 +471,7 @@ cc_library( ":cpu_runtime", ":ir_emission_utils", ":mlir_emitter", + ":mlir_matmul_codegen_strategy", ":target_machine_features", ":tiled_dot_emitter", ":vector_support_library", @@ -1102,12 +1103,33 @@ cc_library( "@llvm-project//llvm:Core", "@llvm-project//llvm:IPO", "@llvm-project//llvm:Linker", + "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", - "@llvm-project//mlir:LLVMTransforms", - "@llvm-project//mlir:LinalgToLLVM", "@llvm-project//mlir:LinalgTransforms", "@llvm-project//mlir:Pass", "@llvm-project//mlir:TargetLLVMIR", + "@llvm-project//mlir:Transforms", "@llvm-project//mlir:VectorToLLVM", ], ) + +cc_library( + name = "mlir_matmul_codegen_strategy", + srcs = ["mlir_matmul_codegen_strategy.cc"], + hdrs = ["mlir_matmul_codegen_strategy.h"], + deps = [ + "@llvm-project//llvm:Support", + "@llvm-project//mlir:Affine", + "@llvm-project//mlir:Analysis", + "@llvm-project//mlir:IR", + "@llvm-project//mlir:LinalgOps", + "@llvm-project//mlir:LinalgTransforms", + "@llvm-project//mlir:Pass", + "@llvm-project//mlir:SCFDialect", + "@llvm-project//mlir:StandardOps", + "@llvm-project//mlir:Support", + "@llvm-project//mlir:Transforms", + "@llvm-project//mlir:VectorOps", + "@llvm-project//mlir:VectorToSCF", + ], +) diff --git a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc index 72f3a4dfac7..ee4bcf4cd35 100644 --- a/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/dot_op_emitter.cc @@ -31,10 +31,12 @@ limitations under the License. #include "mlir/IR/MLIRContext.h" // from @llvm-project #include "mlir/IR/OperationSupport.h" // from @llvm-project #include "mlir/IR/Value.h" // from @llvm-project +#include "tensorflow/compiler/xla/primitive_util.h" #include "tensorflow/compiler/xla/service/cpu/cpu_options.h" #include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h" #include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h" #include "tensorflow/compiler/xla/service/cpu/mlir_emitter.h" +#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" #include "tensorflow/compiler/xla/service/cpu/target_machine_features.h" #include "tensorflow/compiler/xla/service/cpu/tiled_dot_emitter.h" #include "tensorflow/compiler/xla/service/cpu/vector_support_library.h" @@ -202,6 +204,20 @@ class DotOpEmitter { .value_or(kDefaultTileSize); } + std::array GetMlirGemmTileSize() const { + // Tile by 4 x registers x register size. This was picked by running + // small matmuls on Haswell and Skylake. There's a lot of room for + // improvement here. 
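+  // Illustrative example (assumed target, not measured as part of this
+  // change): with 16 vector registers of 8 f32 elements each (e.g. AVX2),
+  // this returns a 4 x 16 x 8 tile.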
+ constexpr int64_t kDefaultTileSizeForM = 4; + int64_t elements_per_register = + target_machine_features_.vector_register_num_elements( + *b_->GetInsertBlock()->getParent(), + dot_info_.result_shape.element_type()); + int64_t num_registers = target_machine_features_.vector_register_count( + *b_->GetInsertBlock()->getParent()); + return {{kDefaultTileSizeForM, num_registers, elements_per_register}}; + } + DotInfo dot_info_; string dot_hlo_name_; const llvm_ir::IrArray& target_array_; @@ -250,6 +266,7 @@ Status DotOpEmitter::EmitLinalgMatmul() { absl::StrCat("linalgMatMul_", dot_info_.result_shape.ToString(true), "_", dot_info_.lhs_shape.ToString(true), "_", dot_info_.rhs_shape.ToString(true)); + return EmitMlirFuncAndCall( mlir_context_, b_, dot_info_.result_shape, operand_shapes, target_ptr, operand_ptrs, name, [&](mlir::OpBuilder* builder, mlir::FuncOp function) { @@ -259,6 +276,27 @@ Status DotOpEmitter::EmitLinalgMatmul() { mlir::edsc::intrinsics::linalg_matmul(mlir::TypeRange{}, mlir::ValueRange{b, c, a}); mlir::edsc::intrinsics::std_ret(); + + mlir::linalg::LinalgTilingOptions tilingOptions; + tilingOptions = tilingOptions.setTileSizes(GetMlirGemmTileSize()); + int64 alignment = + target_machine_features_.minimum_alignment_for_allocation( + ShapeUtil::ByteSizeOf(dot_info_.result_shape)); + mlir_strategy::MatmulCodegenStrategy strategy; + strategy.tile(tilingOptions) + .promote( + mlir::linalg::LinalgPromotionOptions() + .setAlignment(alignment) + .setUseFullTileBuffersByDefault(true) + .setUseAlloca(true)) + .vectorize() + .setVectorTransformsOptions( + mlir::vector::VectorTransformsOptions() + .setVectorTransformsOptions( + mlir::vector::VectorContractLowering::OuterProduct)) + .setVectorTransferToSCFOptions( + mlir::VectorTransferToSCFOptions().setUnroll(true)); + strategy.transform(function); }); } diff --git a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc index e7d52c288d5..d17f4671327 100644 --- a/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc +++ b/tensorflow/compiler/xla/service/cpu/mlir_emitter.cc @@ -17,14 +17,14 @@ limitations under the License. 
#include "llvm/Linker/Linker.h" #include "llvm/Transforms/IPO/Internalize.h" -#include "mlir/Conversion/LinalgToLLVM/LinalgToLLVM.h" // from @llvm-project -#include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h" // from @llvm-project #include "mlir/Dialect/Linalg/Passes.h" // from @llvm-project #include "mlir/IR/Module.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassManager.h" // from @llvm-project #include "mlir/Target/LLVMIR.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project #include "tensorflow/compiler/mlir/xla/hlo_utils.h" namespace xla { @@ -35,9 +35,9 @@ namespace { std::unique_ptr MakeLLVMModule(mlir::OwningModuleRef module) { mlir::PassManager manager(module->getContext()); manager.addPass(mlir::createConvertLinalgToLoopsPass()); - manager.addPass(mlir::createConvertLinalgToLLVMPass()); + manager.addPass(mlir::createLowerAffinePass()); + manager.addPass(mlir::createLowerToCFGPass()); manager.addPass(mlir::createConvertVectorToLLVMPass()); - manager.addPass(mlir::createLowerToLLVMPass()); CHECK(succeeded(manager.run(*module))); return mlir::translateModuleToLLVMIR(*module); } diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc new file mode 100644 index 00000000000..ea89071a967 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.cc @@ -0,0 +1,269 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. 
+==============================================================================*/ + +#include "tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h" + +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "mlir/Analysis/SliceAnalysis.h" // from @llvm-project +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project +#include "mlir/Dialect/Affine/IR/AffineOps.h" // from @llvm-project +#include "mlir/Dialect/Linalg/IR/LinalgOps.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Transforms/Hoisting.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Utils/Utils.h" // from @llvm-project +#include "mlir/Dialect/SCF/SCF.h" // from @llvm-project +#include "mlir/Dialect/SCF/Utils.h" // from @llvm-project +#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project +#include "mlir/Dialect/Vector/EDSC/Intrinsics.h" // from @llvm-project +#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project +#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project +#include "mlir/IR/AffineExpr.h" // from @llvm-project +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/BlockAndValueMapping.h" // from @llvm-project +#include "mlir/IR/Dominance.h" // from @llvm-project +#include "mlir/IR/MLIRContext.h" // from @llvm-project +#include "mlir/IR/OperationSupport.h" // from @llvm-project +#include "mlir/IR/PatternMatch.h" // from @llvm-project +#include "mlir/IR/StandardTypes.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassManager.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project +#include "mlir/Transforms/LoopUtils.h" // from @llvm-project +#include "mlir/Transforms/Passes.h" // from @llvm-project + +// TODO(kramerb): Remove this once strategy is in mlir core. + +using namespace mlir; // NOLINT +using namespace mlir::linalg; // NOLINT + +#define DEBUG_TYPE "matmul-codegen-strategy" + +namespace xla { +namespace cpu { +namespace mlir_strategy { + +//===----------------------------------------------------------------------===// +// TODO: Cleanup and upstream these to go into core. Please ignore for now ! +//===----------------------------------------------------------------------===// +static void hoistRedundantCopies(FuncOp func) { + bool changed = true; + while (changed) { + changed = false; + func.walk([&](linalg::FillOp op) { + auto loop = op.getParentOfType(); + if (!loop) return; + + for (auto operand : op.getOperands()) + if (!loop.isDefinedOutsideOfLoop(operand)) return; + + // Hoist fill before. + op.getOperation()->moveBefore(loop); + changed = true; + }); + + func.walk([&](linalg::CopyOp op) { + auto loop = op.getParentOfType(); + if (!loop) return; + + for (auto operand : op.getOperands()) + if (!loop.isDefinedOutsideOfLoop(operand)) return; + + Value sourceView = op.getInput(0); + while (auto subViewOp = sourceView.getDefiningOp()) + sourceView = subViewOp.getViewSource(); + + // Source traces back to a block argument. 
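+      // Copies reading from a block argument are hoisted above the loop,
+      // while copies whose source is a view created inside the function
+      // (e.g. a local allocation or subview) are moved below it instead.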
+ if (sourceView.isa()) { + op.getOperation()->moveBefore(loop); + } else { + assert(sourceView.getDefiningOp() || + sourceView.getDefiningOp() || + sourceView.getDefiningOp()); + op.getOperation()->moveAfter(loop); + } + changed = true; + }); + } +} + +/// Substitute scf.for = %lb to %ub step %step by an AffineExpr expressing: +/// `%lb + %step * new_dim` where +/// 1. the AffineExpr for %lb is either an AffineConstantExpr or an +/// AffineDimExpr depending on whether the value is constant or not. +/// 2. the AffineExpr for %step is either an AffineConstantExpr or an +/// AffineSymbolExpr depending on whether the value is constant or not. +/// +static void substitute(scf::ForOp forOp, SmallVectorImpl &exprs, + SmallVectorImpl &dims, + SmallVectorImpl &symbols) { + MLIRContext *ctx = forOp.getContext(); + auto lbConstant = forOp.lowerBound().getDefiningOp(); + AffineExpr lb = lbConstant ? getAffineConstantExpr(lbConstant.getValue(), ctx) + : getAffineDimExpr(dims.size(), ctx); + + auto stepConstant = forOp.step().getDefiningOp(); + AffineExpr step = stepConstant + ? getAffineConstantExpr(stepConstant.getValue(), ctx) + : getAffineSymbolExpr(symbols.size(), ctx); + + if (!lbConstant) dims.push_back(forOp.lowerBound()); + if (!stepConstant) symbols.push_back(forOp.step()); + exprs.push_back(lb + step * getAffineDimExpr(dims.size(), ctx)); + + auto ubConstant = forOp.upperBound().getDefiningOp(); + AffineExpr ub = ubConstant ? getAffineConstantExpr(ubConstant.getValue(), ctx) + : getAffineDimExpr(dims.size(), ctx); + if (!ubConstant) dims.push_back(forOp.upperBound()); + exprs.push_back(ub); + + dims.push_back(forOp.getInductionVar()); +} + +/// Traverse the . +static void substitute(AffineMinOp minOp, SmallVectorImpl &exprs, + SmallVectorImpl &dims, + SmallVectorImpl &symbols) { + MLIRContext *ctx = minOp.getContext(); + for (Value v : minOp.getDimOperands()) { + if (auto forOp = scf::getForInductionVarOwner(v)) { + substitute(forOp, exprs, dims, symbols); + continue; + } + if (auto parentMinOp = v.getDefiningOp()) { + substitute(parentMinOp, exprs, dims, symbols); + continue; + } + exprs.push_back(getAffineDimExpr(dims.size(), ctx)); + dims.push_back(v); + } +} + +/// Perform folding of chains of AffineMinOp. 
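+/// The pattern takes the smallest constant result of the affine_min, expands
+/// the remaining operands through the bounds of their enclosing scf.for ops,
+/// and folds the affine_min to that constant when every expanded result is a
+/// provable multiple of it.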
+struct AffineMinCanonicalizationPattern : public OpRewritePattern { + using OpRewritePattern::OpRewritePattern; + + LogicalResult matchAndRewrite(AffineMinOp minOp, + PatternRewriter &rewriter) const override; +}; + +LogicalResult AffineMinCanonicalizationPattern::matchAndRewrite( + AffineMinOp minOp, PatternRewriter &rewriter) const { + LLVM_DEBUG(llvm::dbgs() << "\nCanonicalize AffineMin: " + << *minOp.getOperation() << "\n"); + + int64_t min = std::numeric_limits::max(); + for (auto e : minOp.map().getResults()) + if (auto cstExpr = e.dyn_cast()) + min = std::min(min, cstExpr.getValue()); + if (min == std::numeric_limits::max()) return failure(); + + SmallVector exprs; + SmallVector dims, symbols; + substitute(minOp, exprs, dims, symbols); + + SmallVector operands = dims; + operands.append(symbols.begin(), symbols.end()); + + MLIRContext *ctx = minOp.getContext(); + auto map = AffineMap::get(dims.size(), symbols.size(), exprs, ctx); + LLVM_DEBUG(llvm::dbgs() << "Substitution map: " << map << "\n"); + + SmallVector modExprs; + for (unsigned idx = 0, e = map.getNumResults(); idx < e; ++idx) + modExprs.push_back(getAffineDimExpr(idx, ctx) % min); + map = AffineMap::get(map.getNumResults(), 0, modExprs, ctx).compose(map); + canonicalizeMapAndOperands(&map, &operands); + map = simplifyAffineMap(map); + + LLVM_DEBUG(llvm::dbgs() << "Post mod: " << map << "\n"; + llvm::interleaveComma(operands, llvm::dbgs())); + + if (!llvm::all_of(map.getResults(), [](AffineExpr e) { + if (auto cst = e.dyn_cast()) + return cst.getValue() == 0; + return false; + })) + return failure(); + + rewriter.replaceOpWithNewOp(minOp, min); + return success(); +} +//===----------------------------------------------------------------------===// +// END TODO +//===----------------------------------------------------------------------===// + +void MatmulCodegenStrategy::transform(FuncOp func) const { + MLIRContext *context = func.getContext(); + // Emplace patterns one at a time while also maintaining a simple chained + // state transition. + unsigned stepCount = 0; + SmallVector stage1Patterns; + auto zeroState = Identifier::get(std::to_string(stepCount), context); + auto currentState = zeroState; + for (auto &t : transformation_sequence) { + auto nextState = Identifier::get(std::to_string(++stepCount), context); + auto marker = (currentState == zeroState) + ? linalg::LinalgMarker({}, nextState) + : linalg::LinalgMarker(currentState, nextState); + stage1Patterns.emplace_back(t->buildRewritePatterns(context, marker)); + currentState = nextState; + } + + OwningRewritePatternList stage2Patterns = + linalg::getLinalgTilingCanonicalizationPatterns(context); + stage2Patterns.insert(context); + + auto stage3Transforms = [](Operation *op) { + // Some of these may be too aggressive as a stage 3 that is applied on each + // stage 1 application and may have to be split out to post staged patterns + // application (in which case they could just be passes, TBD). 
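+    // Concretely, this currently runs loop-invariant code motion, then
+    // promotes single-iteration loops and hoists view allocations, redundant
+    // vector transfers and redundant copies on the function.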
+ PassManager pm(op->getContext()); + pm.addPass(createLoopInvariantCodeMotionPass()); + if (failed(pm.run(op->getParentOfType()))) + llvm_unreachable("Unexpected failure in cleanup pass pipeline."); + promoteSingleIterationLoops(cast(op)); + hoistViewAllocOps(cast(op)); + hoistRedundantVectorTransfers(cast(op)); + hoistRedundantCopies(cast(op)); + return success(); + }; + linalg::applyStagedPatterns(func, stage1Patterns, stage2Patterns, + stage3Transforms); + + //===--------------------------------------------------------------------===// + // Post staged patterns transforms + //===--------------------------------------------------------------------===// + // Programmatic controlled lowering of vector.contract only. + OwningRewritePatternList vectorContractLoweringPatterns; + vectorContractLoweringPatterns + .insert( + vector_transforms_options, context); + applyPatternsAndFoldGreedily(func, vectorContractLoweringPatterns); + + // Programmatic controlled lowering of vector.transfer only. + OwningRewritePatternList vectorToLoopsPatterns; + populateVectorToSCFConversionPatterns(vectorToLoopsPatterns, context, + vector_to_scf_options); + applyPatternsAndFoldGreedily(func, vectorToLoopsPatterns); +} + +} // namespace mlir_strategy +} // namespace cpu +} // namespace xla diff --git a/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h new file mode 100644 index 00000000000..3b11b750c47 --- /dev/null +++ b/tensorflow/compiler/xla/service/cpu/mlir_matmul_codegen_strategy.h @@ -0,0 +1,188 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#ifndef MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ +#define MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ + +#include "llvm/ADT/ArrayRef.h" +#include "llvm/ADT/SmallVector.h" +#include "llvm/ADT/StringSwitch.h" +#include "mlir/Conversion/VectorToSCF/VectorToSCF.h" // from @llvm-project +#include "mlir/Dialect/Linalg/Transforms/Transforms.h" // from @llvm-project +#include "mlir/Dialect/Vector/VectorOps.h" // from @llvm-project +#include "mlir/Dialect/Vector/VectorTransforms.h" // from @llvm-project +#include "mlir/Support/LLVM.h" // from @llvm-project + +// TODO(kramerb): Remove this once strategy is in mlir core. + +namespace xla { +namespace cpu { +namespace mlir_strategy { + +/// Abstract Transformation class applied in a sequence that also handles state +/// through markers. +struct Transformation { + virtual ~Transformation() = default; + virtual mlir::OwningRewritePatternList buildRewritePatterns( + mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) = 0; + mlir::linalg::LinalgMarker marker; +}; + +/// Promotion transformation enqueues a particular stage-1 pattern for +/// `Tile`with the appropriate `options`. +// TODO: variadic LinalgOpTypes. 
+template +struct Tile : public Transformation { + explicit Tile(mlir::linalg::LinalgTilingOptions options) : options(options) {} + + mlir::OwningRewritePatternList buildRewritePatterns( + mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { + mlir::OwningRewritePatternList tiling_patterns; + tiling_patterns.insert>( + context, options, m); + return tiling_patterns; + } + + private: + mlir::linalg::LinalgTilingOptions options; +}; + +/// Promotion transformation enqueues a particular stage-1 pattern for +/// `Promote`with the appropriate `options`. +// TODO: variadic LinalgOpTypes. +template +struct Promote : public Transformation { + explicit Promote(mlir::linalg::LinalgPromotionOptions options) + : options(options) {} + + mlir::OwningRewritePatternList buildRewritePatterns( + mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { + mlir::OwningRewritePatternList promotion_patterns; + promotion_patterns + .insert>(context, + options, m); + return promotion_patterns; + } + + private: + mlir::linalg::LinalgPromotionOptions options; +}; + +/// Vectorization transformation enqueues a particular stage-1 pattern for +/// `LinalgVectorizationPattern` as well as copy to vector +/// transfer rewrite forwarding patterns. +// TODO: variadic LinalgOpTypes. +template +struct Vectorize : public Transformation { + mlir::OwningRewritePatternList buildRewritePatterns( + mlir::MLIRContext *context, mlir::linalg::LinalgMarker m) override { + mlir::OwningRewritePatternList vectorization_patterns; + // FillOp may interfere with forwarding patterns atm, so we bump up the + // priority of LinalgCopyVTRForwardingPattern / + // LinalgCopyVTWForwardingPattern. + vectorization_patterns + .insert>(context, + m); + vectorization_patterns.insert( + context, + /*benefit=*/2); + return vectorization_patterns; + } +}; + +/// Matmul-specific strategy object controls how a linalg.matmul is +/// progressively lowered. +/// The strategy uses a 3-level staged patterns strategy which allows ordering +/// transformations by using the Linalg `applyStagedPatterns` function, where: +/// 1. The first stage consists of the successive `tile`, `promote` and +/// `vectorize` patterns, applied sequentially. +/// 2. The second stage consists of common local canonicalization patterns +/// that are applied eagerly after each stage-1 pattern. +/// 3. the third stage consists of more global transformation, also applied +/// eagerly, after all stage-2 patterns. Such more global transformations +struct MatmulCodegenStrategy { + /// Append a pattern to add a level of tiling for `LinalgOpType` with tiling + /// `options`. + template + MatmulCodegenStrategy &tile(mlir::linalg::LinalgTilingOptions options) { + transformation_sequence.emplace_back(new Tile(options)); + return *this; + } + /// Conditionally append a pattern to add a level of tiling for `LinalgOpType` + /// with tiling `options`. + template + MatmulCodegenStrategy &tileIf(bool b, + mlir::linalg::LinalgTilingOptions options) { + return b ? tile(options) : *this; + } + /// Append a pattern to add a level of promotion for `LinalgOpType` with + /// promotion `options`. + template + MatmulCodegenStrategy &promote(mlir::linalg::LinalgPromotionOptions options) { + transformation_sequence.emplace_back(new Promote(options)); + return *this; + } + /// Conditionally append a pattern to add a level of promotion for + /// `LinalgOpType` with promotion `options`. 
+ template + MatmulCodegenStrategy &promoteIf( + bool b, mlir::linalg::LinalgPromotionOptions options) { + return b ? promote(options) : *this; + return *this; + } + /// Append a pattern to rewrite `LinalgOpType` as a vector operation. + template + MatmulCodegenStrategy &vectorize() { + transformation_sequence.emplace_back(new Vectorize()); + return *this; + } + /// Conditionally append a pattern to rewrite `LinalgOpType` as a vector + /// operation. + template + MatmulCodegenStrategy &vectorizeIf(bool b) { + return b ? vectorize() : *this; + return *this; + } + /// Configure the post staged-patterns late vector transformations. + MatmulCodegenStrategy &setVectorTransformsOptions( + mlir::vector::VectorTransformsOptions options) { + vector_transforms_options = options; + return *this; + } + /// Configure the post staged-patterns late vector.transfer to scf conversion. + MatmulCodegenStrategy &setVectorTransferToSCFOptions( + mlir::VectorTransferToSCFOptions options) { + vector_to_scf_options = options; + return *this; + } + + /// Apply the transformation patterns in sequence with cleanup transformations + /// interleaved. + void transform(mlir::FuncOp func) const; + + private: + mlir::LogicalResult postPatternTransforms(mlir::Operation *func) const; + + mlir::vector::VectorTransformsOptions vector_transforms_options; + mlir::VectorTransferToSCFOptions vector_to_scf_options; + llvm::SmallVector, 4> transformation_sequence; +}; + +} // namespace mlir_strategy +} // namespace cpu +} // namespace xla + +#endif // MLIR_EDGE_BENCHMARKS_STRATEGIES_MATMULCODEGENSTRATEGIES_H_ diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features.h b/tensorflow/compiler/xla/service/cpu/target_machine_features.h index a383b4a4a00..52c26d24fe7 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features.h @@ -52,6 +52,12 @@ class TargetMachineFeatures { virtual int vector_register_num_elements(const llvm::Function& function, PrimitiveType type) const = 0; + // Return the number of vector registers. We need to pass in + // "function" since llvm functions can contain annotations for specializing + // them to specific micro-architectures (though currently XLA does not use + // this functionality). + virtual int vector_register_count(const llvm::Function& function) const = 0; + // Returns the minimum alignment for a buffer of size size_bytes. 
virtual int64 minimum_alignment_for_allocation(int64 size_bytes) const = 0; @@ -84,6 +90,12 @@ class LLVMTargetMachineFeatures : public TargetMachineFeatures { (primitive_util::BitWidth(type) / 8); } + int vector_register_count(const llvm::Function& function) const override { + llvm::TargetTransformInfo* tti = GetTargetTransformInfoFor(function); + return static_cast(tti->getNumberOfRegisters( + tti->getRegisterClassForType(/*Vector=*/true))); + } + int64 minimum_alignment_for_allocation(int64 size_bytes) const override; private: diff --git a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h index ffc6927cbe1..fbbd0d2233d 100644 --- a/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h +++ b/tensorflow/compiler/xla/service/cpu/target_machine_features_fake.h @@ -44,6 +44,10 @@ class TargetMachineFeaturesWithFakeAlignmentLogic LOG(FATAL) << "Unexpected call to " << __func__; } + int vector_register_count(const llvm::Function& function) const override { + LOG(FATAL) << "Unexpected call to " << __func__; + } + int64 minimum_alignment_for_allocation(int64 size_bytes) const override { return fake_alignment_logic_(size_bytes); } From c164c98a4bd6b5e569786e2cbfaa0016899dd328 Mon Sep 17 00:00:00 2001 From: Adrian Kuegel Date: Wed, 8 Jul 2020 07:34:16 -0700 Subject: [PATCH 37/88] Remove broken (and unused) kernel targets. PiperOrigin-RevId: 320184343 Change-Id: I10d3778dbb0dc86450c6e1ad69cd35b9fdffe8a0 --- tensorflow/core/kernels/mlir_generated/BUILD | 42 ++++++++++---------- 1 file changed, 22 insertions(+), 20 deletions(-) diff --git a/tensorflow/core/kernels/mlir_generated/BUILD b/tensorflow/core/kernels/mlir_generated/BUILD index 4cf58540fcf..fed63ce8433 100644 --- a/tensorflow/core/kernels/mlir_generated/BUILD +++ b/tensorflow/core/kernels/mlir_generated/BUILD @@ -57,27 +57,29 @@ tf_cuda_cc_test( ], ) -gen_kernel_library( - name = "bias_add", - same_shape = "0,2", - tile_size = "16x16", - types = [ - "f16", - "f32", - "f64", - ], -) +# TODO(b/160731748): Re-enable when it works again. +# gen_kernel_library( +# name = "bias_add", +# same_shape = "0,2", +# tile_size = "16x16", +# types = [ +# "f16", +# "f32", +# "f64", +# ], +# ) -gen_kernel_library( - name = "relu", - same_shape = "0,1", - tile_size = "256", - types = [ - "f16", - "f32", - "f64", - ], -) +# TODO(b/160190568): Re-enable when it works again. +# gen_kernel_library( +# name = "relu", +# same_shape = "0,1", +# tile_size = "256", +# types = [ +# "f16", +# "f32", +# "f64", +# ], +# ) gen_kernel_library( name = "tanh", From eec4fbac37a766f23b533e1e067efad1c37aae0c Mon Sep 17 00:00:00 2001 From: Stephan Herhut Date: Wed, 8 Jul 2020 07:39:18 -0700 Subject: [PATCH 38/88] Integrate LLVM at https://github.com/llvm/llvm-project/commit/1ea289681acf PiperOrigin-RevId: 320185037 Change-Id: Ia2149c852d6cf0a07d36c96aace3bf1c3fcb12b9 --- tensorflow/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index 26a8e8271a8..e1608359623 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -710,8 +710,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. 
- LLVM_COMMIT = "f54d0e36be6a4d5dab67244e85b8664282dcf5d1" - LLVM_SHA256 = "df115acb6b5b1a5e2f49819a0b6ae8cf47ccecb61a10d75f85a67198ff06420e" + LLVM_COMMIT = "1ea289681acf622ceda783c8fda2f16754b7c933" + LLVM_SHA256 = "c8772fc7c664cf06af2b0683e1f17555ff56132c78937c76f77d5308428957e4" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), From e1ff0e7e812d2197480c9af18ace99739993f835 Mon Sep 17 00:00:00 2001 From: Bixia Zheng Date: Wed, 8 Jul 2020 08:26:55 -0700 Subject: [PATCH 39/88] [TF:TRT] Fix the TF-TRT bridge to respect the device assignment information in the graph nodes. Previously, the TF-TRT bridge may put the graph nodes that are intended to run on a non-GPU device to run with TensorRT. It may also put two graph nodes that are intended to run on two specific but different GPUs to one TensorRT engine which will make the two nodes run on the same GPUs. This change modifies the bridge to only put nodes that are not intended to run on non-GPU devices to run with TensorRT and to only group nodes with compatible device assignments into a clust for a TensorRT engine. Add tests to test cluster formation in the presence of device assignment. Modify an existing test to test conversion with device assignment. PiperOrigin-RevId: 320191854 Change-Id: I76aea67abc7634d73ba07bccec52ac25df65735b --- tensorflow/compiler/tf2tensorrt/BUILD | 2 + .../tf2tensorrt/convert/convert_graph.cc | 60 ++++------- .../compiler/tf2tensorrt/convert/utils.cc | 39 +++++++ .../compiler/tf2tensorrt/convert/utils.h | 21 ++++ .../compiler/tf2tensorrt/segment/segment.cc | 102 ++++++++---------- .../tf2tensorrt/segment/segment_test.cc | 63 +++++++++++ .../compiler/tf2tensorrt/segment/union_find.h | 27 ++++- .../compiler/tensorrt/trt_convert_test.py | 31 ++++-- 8 files changed, 236 insertions(+), 109 deletions(-) diff --git a/tensorflow/compiler/tf2tensorrt/BUILD b/tensorflow/compiler/tf2tensorrt/BUILD index 368cb5af2ed..0718bd8cd65 100644 --- a/tensorflow/compiler/tf2tensorrt/BUILD +++ b/tensorflow/compiler/tf2tensorrt/BUILD @@ -501,6 +501,7 @@ cc_library( copts = tf_copts(), deps = [ ":common_utils", + ":utils", "//tensorflow/core:graph", "//tensorflow/core:lib", "//tensorflow/core:lib_internal", @@ -570,6 +571,7 @@ cc_library( deps = [ "@com_google_absl//absl/algorithm:container", "//tensorflow/core:framework", + "//tensorflow/core:graph", "//tensorflow/core:lib_proto_parsing", "//tensorflow/core:lib", ] + if_tensorrt([":tensorrt_lib"]), diff --git a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc index 5429aaf3362..c9210a1a1e7 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/convert_graph.cc @@ -130,7 +130,9 @@ Status GetEngineInfo(const Graph* g, EngineInfo* info) { std::vector subgraph_nodes; // Topologically sorted nodes. std::set added_const_nodes; // Used to prevent double insertion. - std::set segment_devices; + // The device assignment accumulated from the compatible device assignments + // for the nodes in the segment. 
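+  // It starts out unconstrained and is narrowed below by merging in the
+  // requested or assigned device of every node in the segment.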
+ DeviceNameUtils::ParsedName segment_device; // Map from src_node_name+port to the unique port numbers of the TRT op, where // the src_node_name is the name of the source node of the input/output @@ -144,36 +146,17 @@ Status GetEngineInfo(const Graph* g, const Node* node = *it; if (segment_nodes.count(node) == 0) continue; - std::string device_name; - if (!node->requested_device().empty()) { - device_name = node->requested_device(); - } else if (node->has_assigned_device_name()) { - // It appears that nodes will not have assigned devices at this point in - // execution. - device_name = node->assigned_device_name(); - } else { - VLOG(2) << "Node " << node->name() - << " neither have requested device nor assigned device"; - } - - if (!device_name.empty()) { - // If device is set, it means device placement may have been done before, - // so we need to assign a device for the TRTEngineOp if the assigned - // device is a GPU device. - DeviceNameUtils::ParsedName parsed_name; - const bool parse_succeeded = - DeviceNameUtils::ParseFullName(device_name, &parsed_name); - if (!parse_succeeded) { - VLOG(1) << "Failed to parse " - << (node->requested_device().empty() ? "assigned" : "requested") - << " device " << device_name << " of node " << node->name(); - } else if (parsed_name.type != "GPU") { - VLOG(1) << "Node " << node->name() - << " was assigned to a non-GPU device " << device_name; - } else { - segment_devices.insert(device_name); - } + absl::optional new_segment_device = + MergeIfCompatible(segment_device, GetDeviceName(node)); + if (!new_segment_device.has_value()) { + // The segmenter should guarantee that nodes in the same segment have + // compatible device assignments. + return errors::Internal( + "segment nodes have incompatible device assignments: ", + DeviceNameUtils::ParsedNameToString(segment_device), " vs ", + GetDeviceName(node), " to node ", node->name()); } + segment_device = *new_segment_device; subgraph_nodes.push_back(node); const int node_id = node->id(); @@ -273,13 +256,16 @@ Status GetEngineInfo(const Graph* g, info->engine_name = StrCat(scope_name, info->engine_name); VLOG(1) << "Converted TensorRT candidate segment '" << info->engine_name << "' to a GraphDef"; - if (segment_devices.size() == 1) { - info->device = *segment_devices.begin(); - } else if (segment_devices.size() > 1) { - LOG_WARNING_WITH_PREFIX - << "Detected multiple (" << segment_devices.size() - << ") devices for the segment. Picking first one to continue."; - info->device = *segment_devices.begin(); + if (segment_device.has_type) { + // If the accumulated device assignment for the segment has a device type, + // the segmenter guarantees the device type is GPU. Use the device + // assignment in this case. 
+ if (segment_device.type != "GPU") { + return errors::Internal( + "segment device is not GPU: ", + DeviceNameUtils::ParsedNameToString(segment_device)); + } + info->device = DeviceNameUtils::ParsedNameToString(segment_device); } else { TfGpuId tf_gpu_id; PlatformGpuId platform_gpu_id; diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.cc b/tensorflow/compiler/tf2tensorrt/convert/utils.cc index a4b64ec0dc5..a69960005fc 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.cc +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.cc @@ -271,5 +271,44 @@ string GetLoadedTensorRTVersion() { return absl::StrCat(major, ".", minor, ".", patch); } +absl::string_view GetDeviceName(const Node* node) { + if (node->has_assigned_device_name()) { + return node->assigned_device_name(); + } + return node->requested_device(); +} + +absl::optional GetDeviceParsedName( + const Node* node) { + absl::string_view device_name = GetDeviceName(node); + DeviceNameUtils::ParsedName parsed_name; + if (!DeviceNameUtils::ParseFullName(device_name, &parsed_name)) { + return absl::nullopt; + } + return parsed_name; +} + +absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, + const DeviceNameUtils::ParsedName& b) { + DeviceNameUtils::ParsedName merged_name = a; + if (!DeviceNameUtils::MergeDevNames(&merged_name, b, + /*allow_soft_placement=*/false) + .ok()) { + return absl::nullopt; + } + return merged_name; +} + +absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, absl::string_view b) { + DeviceNameUtils::ParsedName b_parsed_name; + if (!DeviceNameUtils::ParseFullName(b, &b_parsed_name)) { + return absl::nullopt; + } + + return MergeIfCompatible(a, b_parsed_name); +} + } // namespace tensorrt } // namespace tensorflow diff --git a/tensorflow/compiler/tf2tensorrt/convert/utils.h b/tensorflow/compiler/tf2tensorrt/convert/utils.h index 775616ff7aa..a0505c3f922 100644 --- a/tensorflow/compiler/tf2tensorrt/convert/utils.h +++ b/tensorflow/compiler/tf2tensorrt/convert/utils.h @@ -21,6 +21,7 @@ limitations under the License. #include "absl/algorithm/container.h" #include "tensorflow/core/framework/tensor_shape.h" +#include "tensorflow/core/graph/graph.h" #include "tensorflow/core/lib/core/status.h" #if GOOGLE_CUDA && GOOGLE_TENSORRT @@ -133,6 +134,26 @@ bool AreShapesCompatible(const std::vector& actual_shapes, // input bindings, because the number of total input bindings equals the number // of profiles times the number of engine inputs. int GetNumberOfEngineInputs(const nvinfer1::ICudaEngine* engine); + +// Returns the string representation for the assigned device or the requested +// device of the given node. +absl::string_view GetDeviceName(const Node* node); + +// Returns the ParsedName representation for the assigned device or the +// requested device string of the given node. If the device string is invalid, +// returns absl::nullopt. +absl::optional GetDeviceParsedName( + const Node* node); + +// If the given two device assignments as compatible, returns the merge of the +// two assignments. Otherwise, returns absl::nullopt. +absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, const DeviceNameUtils::ParsedName& b); +// Similar to the above, except that the second device assignment is represented +// by a string_view. 
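+// For example, merging a parsed "/device:GPU:*" with "/device:GPU:1" yields
+// /device:GPU:1, whereas a GPU assignment merged with a CPU assignment is
+// incompatible and yields absl::nullopt.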
+absl::optional MergeIfCompatible( + const DeviceNameUtils::ParsedName& a, absl::string_view b); + #endif // GOOGLE_CUDA && GOOGLE_TENSORRT } // namespace tensorrt diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment.cc b/tensorflow/compiler/tf2tensorrt/segment/segment.cc index bca2aba4650..e7820ca41fe 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment.cc @@ -23,6 +23,7 @@ limitations under the License. #include "absl/strings/str_cat.h" #include "absl/strings/str_format.h" #include "tensorflow/compiler/tf2tensorrt/common/utils.h" +#include "tensorflow/compiler/tf2tensorrt/convert/utils.h" #include "tensorflow/compiler/tf2tensorrt/segment/union_find.h" #include "tensorflow/core/common_runtime/graph_constructor.h" #include "tensorflow/core/graph/algorithm.h" @@ -664,11 +665,15 @@ ClusterBatchSize GetClusterBatchSizeForNode( void AddSegmentForNode(const grappler::GraphProperties* graph_properties, std::vector>* segments, - SimpleNode* node, bool use_implicit_batch) { + SimpleNode* node, + const DeviceNameUtils::ParsedName& device_name, + bool use_implicit_batch) { segments->emplace_back( - node, GetClusterBatchSizeForNode( - graph_properties, node == nullptr ? nullptr : node->tf_node(), - use_implicit_batch)); + node, + GetClusterBatchSizeForNode(graph_properties, + node == nullptr ? nullptr : node->tf_node(), + use_implicit_batch), + device_name); } } // namespace @@ -734,7 +739,13 @@ Status SegmentGraph(const Graph* tf_graph, num_unsupported_ops++; node = nullptr; }; - if (options.exclude_node_list.count(node->name()) != 0) { + absl::optional device_name = + GetDeviceParsedName(node->tf_node()); + // GetDeviceParseName capitalizes the device type. + if (!device_name.has_value() || + (device_name->has_type && device_name->type != "GPU")) { + exclude_node("node can't be placed on GPU"); + } else if (options.exclude_node_list.count(node->name()) != 0) { exclude_node("excluded by segmenter option"); } else if (options.use_implicit_batch && !OperationCanBeTranslatedToImplicitBatch(graph_properties, @@ -763,7 +774,7 @@ Status SegmentGraph(const Graph* tf_graph, << "(Op name: " << node->name(); } } - AddSegmentForNode(graph_properties, &node_segments, node, + AddSegmentForNode(graph_properties, &node_segments, node, *device_name, options.use_implicit_batch); } string msg = StrCat( @@ -809,6 +820,8 @@ Status SegmentGraph(const Graph* tf_graph, // contracting an output edge may unblock new edges for contracting. ClusterBatchSize expected_batch_size = node_segments[node->id()].BatchSize(); + DeviceNameUtils::ParsedName expected_device_name = + node_segments[node->id()].DeviceName(); VLOG(3) << "batch size " << expected_batch_size; while (true) { std::set contract_edges; @@ -821,26 +834,39 @@ Status SegmentGraph(const Graph* tf_graph, VLOG(3) << "... ... Control Edge, Skipping"; continue; } + UnionFind* out_cluster = + &node_segments[out_edge->dst()->id()]; // Out node must be a TRT candidate. - if (node_segments[out_edge->dst()->id()].Value() == nullptr) { + if (out_cluster->Value() == nullptr) { VLOG(3) << "... ... not a TRT candidate"; continue; } // Out node must have compatible batch size. - ClusterBatchSize out_batch_size = - node_segments[out_edge->dst()->id()].BatchSize(); + ClusterBatchSize out_batch_size = out_cluster->BatchSize(); ClusterBatchSize merged_batch_size = expected_batch_size; if (!merged_batch_size.MergeIfCompatible(out_batch_size)) { - VLOG(3) << "... ... incompatible batch size " + VLOG(3) << "... ... 
incompatible batch sizes " << expected_batch_size.ToString() << " " << out_batch_size.ToString(); continue; } + + const DeviceNameUtils::ParsedName& out_device_name = + out_cluster->DeviceName(); + absl::optional merged_device_name = + MergeIfCompatible(expected_device_name, out_device_name); + if (!merged_device_name.has_value()) { + VLOG(3) << "... ... incompatible device names " + << expected_device_name << " " << out_device_name; + continue; + } + if (CanContractEdge(out_edge, graph)) { VLOG(3) << "... ... can contract. new batch size " << merged_batch_size.ToString(); contract_edges.insert(out_edge); expected_batch_size = merged_batch_size; + expected_device_name = *merged_device_name; } else { VLOG(3) << "... ... cannot contract, would form cycle"; } @@ -872,12 +898,14 @@ Status SegmentGraph(const Graph* tf_graph, graph->RemoveEdge(r); } } - ClusterBatchSize actual_batch_size = - node_segments[node->id()].BatchSize(); - if (expected_batch_size != actual_batch_size) { + if (expected_batch_size != node_segments[node->id()].BatchSize()) { return errors::Internal( "expected batch size is not the same as the actual batch size"); } + if (expected_device_name != node_segments[node->id()].DeviceName()) { + return errors::Internal( + "expected device name is not the same as the actual device name"); + } } } @@ -888,34 +916,9 @@ Status SegmentGraph(const Graph* tf_graph, // the segment tree) to the segment nodes set. std::map> sg_map; - // A map from the segment identifier (currently the name of the root node of - // the segment tree) to the device names that the nodes in the segment are - // assigned to. - // - // TODO(aaroey): nodes assigned to different devices should not be merged, - // fix this. - std::unordered_map> device_maps; - for (auto& u : node_segments) { if ((u.Value() != nullptr) && (u.ParentValue() != nullptr)) { sg_map[u.ParentValue()->name()].insert(u.Value()->tf_node()); - auto tf_node = u.Value()->tf_node(); - // has_assigned_device_name() is expected to return true - // when called from optimization pass. However, since graph - // is converted back and forth between graph and graphdef, - // assigned devices demoted to requested devices. If the graph - // is passed directly to this module, assigned devices will be set. 
- if (tf_node->has_assigned_device_name()) { - device_maps[u.ParentValue()->name()].insert( - tf_node->assigned_device_name()); - } else if (!tf_node->requested_device().empty()) { - device_maps[u.ParentValue()->name()].insert( - tf_node->requested_device()); - } else { - VLOG(2) << "Node " << tf_node->name() - << " has no device assigned requested device is: " - << tf_node->requested_device(); - } } } @@ -1034,30 +1037,9 @@ Status SegmentGraph(const Graph* tf_graph, continue; } - const auto& dev_itr = device_maps.find(segment_root); - if (dev_itr == device_maps.end() || dev_itr->second.empty()) { - VLOG(1) << "No device assigned to segment " << segments->size(); - } else if (dev_itr->second.size() > 1) { - string s = StrCat("Segment ", segments->size(), - " has multiple devices attached: "); - for (const auto& dev : dev_itr->second) { - StrAppend(&s, dev, ", "); - } - LOG_WARNING_WITH_PREFIX << s; - } - segments->emplace_back(segment_nodes); } - if (VLOG_IS_ON(1)) { - for (const auto& d : device_maps) { - string s("Segment "); - StrAppend(&s, ": '", d.first, "' "); - for (const auto& dd : d.second) { - StrAppend(&s, dd, ", "); - } - VLOG(1) << "Devices " << s; - } - } + return Status::OK(); } diff --git a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc index f3bc5bfbee6..bf277328fe7 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc +++ b/tensorflow/compiler/tf2tensorrt/segment/segment_test.cc @@ -179,6 +179,69 @@ TEST_F(SegmentTest, Simple) { RunTest(&g, all_adds, all_adds, without_add3, {all_adds}); } +TEST_F(SegmentTest, WithDeviceAssignments) { + // feed + // // \\ + // add0 add1 + // | \ / + // | add2 + // | / \\ + // add3 add4 + // \ / + // + Scope s = Scope::NewRootScope(); + auto feed = ops::Placeholder(s.WithOpName("feed"), DT_FLOAT); + auto add0 = ops::Add(s.WithOpName("add0"), feed, feed); + auto add1 = ops::Add(s.WithOpName("add1"), feed, feed); + auto add2 = ops::Add(s.WithOpName("add2"), add0, add1); + auto add3 = ops::Add(s.WithOpName("add3"), add0, add2); + auto add4 = ops::Add(s.WithOpName("add4"), add2, add2); + + const std::set all_adds = {"add0", "add1", "add2", "add3", "add4"}; + DisableImplicitBatchMode(); + + { + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {all_adds}); + } + + { + // Assigning add1 to CPU to exclude it from the cluster. + add1.node()->set_assigned_device_name("/device:CPU:0"); + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {all_adds - "add1"}); + add1.node()->set_assigned_device_name(""); + } + + { + // Assigning operations add3 and add4 to another GPU to exclude the + // operation from the cluster. + constexpr char kGpu0[] = "/device:GPU:0"; + add0.node()->set_assigned_device_name(kGpu0); + add1.node()->set_assigned_device_name(kGpu0); + add2.node()->set_assigned_device_name(kGpu0); + constexpr char kGpu1[] = "/device:GPU:1"; + add3.node()->set_assigned_device_name(kGpu1); + add4.node()->set_assigned_device_name(kGpu1); + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {{"add0", "add1", "add2"}}); + } + + { + // Assigning the operations to two compatibile GPU devices resulting in + // one cluster with all operations. 
+ constexpr char kGpuAny[] = "/device:GPU:*"; + add3.node()->set_assigned_device_name(kGpuAny); + add4.node()->set_assigned_device_name(kGpuAny); + Graph g(OpRegistry::Global()); + TF_EXPECT_OK(s.ToGraph(&g)); + RunTest(&g, all_adds, all_adds, all_adds, {all_adds}); + } +} + TEST_F(SegmentTest, AvoidCycle) { // feed // // \\ diff --git a/tensorflow/compiler/tf2tensorrt/segment/union_find.h b/tensorflow/compiler/tf2tensorrt/segment/union_find.h index b53615ec019..b91f5771ce5 100644 --- a/tensorflow/compiler/tf2tensorrt/segment/union_find.h +++ b/tensorflow/compiler/tf2tensorrt/segment/union_find.h @@ -149,9 +149,11 @@ template class UnionFind { public: UnionFind() : size_(1), parent_(nullptr) {} - explicit UnionFind(const T& v, ClusterBatchSize batch_size) + UnionFind(const T& v, ClusterBatchSize batch_size, + const DeviceNameUtils::ParsedName& device_name) : size_(1), cluster_batch_size_(batch_size), + cluster_device_name_(device_name), parent_(nullptr), value_(v) {} @@ -159,10 +161,16 @@ class UnionFind { // this object to the root of the cluster. int Size() { return FindRoot()->size_; } - // Returns the batch size of the cluster and compress the path from this + // Returns the batch size of the cluster and compresses the path from this // object to the root object. ClusterBatchSize BatchSize() { return FindRoot()->cluster_batch_size_; } + // Returns the device name of the cluster and compresses the path from this + // object to the root object. + const DeviceNameUtils::ParsedName& DeviceName() { + return FindRoot()->cluster_device_name_; + } + // Merges this cluster with 'other'. This cluster's size_ is updated to // the size of the merged cluster; the size_ of 'other' becomes inaccessible // as only the size_ of the root object is accessible. @@ -181,6 +189,7 @@ class UnionFind { int size_; ClusterBatchSize cluster_batch_size_; + DeviceNameUtils::ParsedName cluster_device_name_; UnionFind* parent_; T value_; }; @@ -192,12 +201,20 @@ Status UnionFind::Merge(UnionFind* other) { if (a == b) return Status::OK(); ClusterBatchSize batch_size = a->cluster_batch_size_; - bool merged = batch_size.MergeIfCompatible(other->cluster_batch_size_); - if (!merged) { - return errors::Internal("trying to merge incompatible cluster."); + if (!batch_size.MergeIfCompatible(other->cluster_batch_size_)) { + return errors::Internal( + "trying to merge clusters with incompatible batch sizes."); + } + + absl::optional device_name = + MergeIfCompatible(a->cluster_device_name_, other->cluster_device_name_); + if (!device_name.has_value()) { + return errors::Internal( + "trying to merge clusters with incompatible device assignment."); } a->cluster_batch_size_ = batch_size; + a->cluster_device_name_ = *device_name; b->parent_ = a; a->size_ += b->size_; return Status::OK(); diff --git a/tensorflow/python/compiler/tensorrt/trt_convert_test.py b/tensorflow/python/compiler/tensorrt/trt_convert_test.py index 8c5c925f026..9052fc2b6ed 100644 --- a/tensorflow/python/compiler/tensorrt/trt_convert_test.py +++ b/tensorflow/python/compiler/tensorrt/trt_convert_test.py @@ -353,18 +353,35 @@ class TrtConvertTest(test_util.TensorFlowTestCase, parameterized.TestCase): self._MayRemoveGraphSequenceNumber(node.name): node.op for node in graph_def.node } - self.assertEqual( - { - "input1": "Placeholder", - "input2": "Placeholder", - "TRTEngineOp_0": "TRTEngineOp", - "output": "Identity" - }, node_name_to_op) + if device is not None and device.startswith("/CPU:"): + self.assertEqual( + { + "add": "AddV2", + "add/ReadVariableOp": 
"Const", + "add_1": "AddV2", + "add_2": "AddV2", + "input1": "Placeholder", + "input2": "Placeholder", + "mul": "Mul", + "output": "Identity" + }, node_name_to_op) + else: + self.assertEqual( + { + "input1": "Placeholder", + "input2": "Placeholder", + "TRTEngineOp_0": "TRTEngineOp", + "output": "Identity" + }, node_name_to_op) if need_calibration: trt_engine_nodes = [ node for node in graph_def.node if node.op == "TRTEngineOp" ] + if device is not None and device.startswith("/CPU:"): + self.assertEmpty(trt_engine_nodes) + return + self.assertNotEmpty(trt_engine_nodes) for node in trt_engine_nodes: self.assertTrue(len(node.attr["calibration_data"].s)) From 7d38b67e5c526d49a7fb61bf7ec13edb7547d44f Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 08:51:34 -0700 Subject: [PATCH 40/88] Qualify uses of std::string PiperOrigin-RevId: 320196263 Change-Id: I895b7295e698f7005b74ffc42bc04eb4a9a87b75 --- tensorflow/core/grappler/costs/op_context.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/grappler/costs/op_context.h b/tensorflow/core/grappler/costs/op_context.h index 6391de4a91e..49474410585 100644 --- a/tensorflow/core/grappler/costs/op_context.h +++ b/tensorflow/core/grappler/costs/op_context.h @@ -25,8 +25,8 @@ namespace grappler { // A structure to keep the context of op execution, including its shape, // execution context, and other relevant information. struct OpContext { - string name; - string device_name; + std::string name; + std::string device_name; OpInfo op_info; const FunctionDefLibrary* function_library; // Not owned. From 67b47fd1c7388ec3f080b41294ae15909f1ee198 Mon Sep 17 00:00:00 2001 From: Smit Hinsu Date: Wed, 8 Jul 2020 08:57:06 -0700 Subject: [PATCH 41/88] Disable MLIR bridge for NMS image ops test MLIR bridge doesn't support tf.NonMaxSuppressionV4 legalization that is conditionally generated by non_max_suppression_padded function. PiperOrigin-RevId: 320197235 Change-Id: If7242133254680b366771ced50de074ed6180563 --- tensorflow/compiler/tests/image_ops_test.py | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/tensorflow/compiler/tests/image_ops_test.py b/tensorflow/compiler/tests/image_ops_test.py index 9590688fda7..326c3ec4929 100644 --- a/tensorflow/compiler/tests/image_ops_test.py +++ b/tensorflow/compiler/tests/image_ops_test.py @@ -30,6 +30,7 @@ from six.moves import xrange # pylint: disable=redefined-builtin from tensorflow.compiler.tests import xla_test from tensorflow.python.framework import dtypes from tensorflow.python.framework import ops +from tensorflow.python.framework import test_util from tensorflow.python.ops import array_ops from tensorflow.python.ops import gen_image_ops from tensorflow.python.ops import image_ops @@ -774,6 +775,7 @@ class ResizeBilinearNonAlignCornersTest(xla_test.XLATestCase): class NonMaxSuppressionTest(xla_test.XLATestCase): + @test_util.disable_mlir_bridge("%1") def testNMS128From1024(self): num_boxes = 1024 boxes_np = np.random.normal(50, 10, (num_boxes, 4)).astype("f4") @@ -808,6 +810,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(indices_tf.size, max_output_size) + @test_util.disable_mlir_bridge("%1") def testNMS3From6Boxes(self): # Three boxes are selected based on IOU. 
boxes_data = [[0, 0, 1, 1], [0, 0.1, 1, 1.1], [0, -0.1, 1, 0.9], @@ -849,6 +852,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 3) self.assertAllClose(indices_tf[:num_valid], [3, 0, 5]) + @test_util.disable_mlir_bridge("%1") def testNMS3Then2WithScoreThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. @@ -891,6 +895,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 2) self.assertAllClose(indices_tf[:num_valid], [3, 0]) + @test_util.disable_mlir_bridge("%1") def testNMS3Then1WithScoreMaxThresh(self): # Three boxes are selected based on IOU. # One is filtered out by score threshold. @@ -934,6 +939,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): self.assertEqual(num_valid, 1) self.assertAllClose(indices_tf[:num_valid], [3]) + @test_util.disable_mlir_bridge("%1") def testSelectFromContinuousOverLap(self): # Tests that a suppressed box does not itself suppress other boxes. @@ -978,6 +984,7 @@ class NonMaxSuppressionTest(xla_test.XLATestCase): class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): + @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1015,6 +1022,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1048,6 +1056,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[0, 1, 2], [0, 1, 3]], indices_output) self.assertAllEqual([3, 3], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6Max3(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1078,6 +1087,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2], indices_output) self.assertAllEqual(3, num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSSingleFrom6NoPad(self): boxes_data = [[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]] @@ -1107,6 +1117,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([0, 1, 2, 4, 5], indices_output) self.assertAllEqual(5, num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSBatchDimsFrom6Max3(self): boxes_data = [[[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1140,6 +1151,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([[[0, 1, 2], [0, 1, 3]]], indices_output) self.assertAllEqual([[3, 3]], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1175,6 +1187,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSUnsortedInputFrom6(self): boxes_data = [[[0, 2, 1, 2], [3, 3, 4, 4], [0, 0, 1, 1], 
[0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8]], @@ -1211,6 +1224,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSNoncanonicalizedInputFrom6(self): boxes_data = [[[1, 0, 0, 1], [4, 3, 3, 4], [1, 0.4, 0, 1.4], [1, 0.6, 0, 1.6], [1, 0.8, 0, 1.8], [1, 2, 0, 2]], @@ -1248,6 +1262,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): indices_output) self.assertAllEqual([5, 4], num_valid_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSScoreThresholdCanInputsFrom6Max3(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], @@ -1283,6 +1298,7 @@ class BatchedNonMaxSuppressionCorrectnessTest(xla_test.XLATestCase): self.assertAllEqual([3, 2], num_valid_output) self.assertAllEqual([[0, 1, 2], [0, 1, invalid_index]], indices_output) + @test_util.disable_mlir_bridge("%1") def testBatchedNMSFrom6DynamicInput(self): boxes_data = [[[0, 0, 1, 1], [3, 3, 4, 4], [0, 0.4, 1, 1.4], [0, 0.6, 1, 1.6], [0, 0.8, 1, 1.8], [0, 2, 1, 2]], From 67c317635bf1a3c05f567c7d060b523b6199a981 Mon Sep 17 00:00:00 2001 From: Allen Lavoie Date: Wed, 8 Jul 2020 09:09:33 -0700 Subject: [PATCH 42/88] Make the "untracked resource" error less awful Searches through gc.get_referrers for a few hops. Seems to find variables and tables, which will cover most cases. PiperOrigin-RevId: 320199554 Change-Id: I69e68ab86d064f689b2248b25fa486fb99096652 --- tensorflow/python/saved_model/save.py | 22 +++++++++++++--- tensorflow/python/saved_model/save_test.py | 30 ++++++++++++++++++++++ 2 files changed, 49 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/saved_model/save.py b/tensorflow/python/saved_model/save.py index 4220146b6c8..27b05b01e34 100644 --- a/tensorflow/python/saved_model/save.py +++ b/tensorflow/python/saved_model/save.py @@ -20,6 +20,7 @@ from __future__ import print_function import collections import functools +import gc import os from tensorflow.core.framework import versions_pb2 @@ -387,12 +388,27 @@ def _map_captures_to_created_tensors(original_captures, resource_map): for exterior, interior in original_captures: mapped_resource = resource_map.get(exterior, None) if mapped_resource is None: + trackable_referrers = [] + # Try to figure out where the resource came from by iterating over objects + # which reference it. This is slow and doesn't help us figure out how to + # match it to other objects when loading the SavedModel as a checkpoint, + # so we can't continue saving. But we can at least tell the user what + # needs attaching. + for primary_referrer in gc.get_referrers(exterior): + if isinstance(primary_referrer, base.Trackable): + trackable_referrers.append(primary_referrer) + for secondary_referrer in gc.get_referrers(primary_referrer): + if isinstance(secondary_referrer, base.Trackable): + trackable_referrers.append(secondary_referrer) raise AssertionError( - ("Tried to export a function which references untracked object {}." + ("Tried to export a function which references untracked resource {}." "TensorFlow objects (e.g. tf.Variable) captured by functions must " "be tracked by assigning them to an attribute of a tracked object " - "or assigned to an attribute of the main object directly." 
- ).format(interior)) + "or assigned to an attribute of the main object directly.\n\n" + "Trackable Python objects referring to this tensor " + "(from gc.get_referrers, limited to two hops):\n{}" + ).format(interior, + "\n".join([repr(obj) for obj in trackable_referrers]))) export_captures.append(mapped_resource) return export_captures diff --git a/tensorflow/python/saved_model/save_test.py b/tensorflow/python/saved_model/save_test.py index 0755f11ff71..7be54df09ec 100644 --- a/tensorflow/python/saved_model/save_test.py +++ b/tensorflow/python/saved_model/save_test.py @@ -179,6 +179,18 @@ class SaveTest(test.TestCase): with self.assertRaisesRegex(ValueError, "ERROR MSG"): save.save(root, os.path.join(self.get_temp_dir(), "saved_model")) + def test_untracked_variable_useful_message(self): + root = module.Module() + v = variables.Variable(1., name="some_unique_name") + + @def_function.function(input_signature=[]) + def f(): + return v.read_value() + + root.f = f + with self.assertRaisesRegex(AssertionError, "some_unique_name"): + save.save(root, os.path.join(self.get_temp_dir(), "saved_model")) + def test_version_information_included(self): root = tracking.AutoTrackable() save_dir = os.path.join(self.get_temp_dir(), "saved_model") @@ -598,6 +610,24 @@ class AssetTests(test.TestCase): {"output_0": [2, 1]}, _import_and_infer(second_dir, {"keys": ["gamma", "beta"]})) + def test_untracked_table_useful_message(self): + root = module.Module() + initializer = lookup_ops.TextFileInitializer( + self._vocab_path, + key_dtype=dtypes.string, + key_index=lookup_ops.TextFileIndex.WHOLE_LINE, + value_dtype=dtypes.int64, + value_index=lookup_ops.TextFileIndex.LINE_NUMBER) + table = lookup_ops.HashTable( + initializer, default_value=-1) + root.table_user = def_function.function( + table.lookup, + input_signature=[tensor_spec.TensorSpec(None, dtypes.string)]) + root.table_user(constant_op.constant("gamma")) + save_dir = os.path.join(self.get_temp_dir(), "saved_model") + with self.assertRaisesRegexp(AssertionError, "HashTable"): + save.save(root, save_dir) + def test_unused_asset(self): root = tracking.AutoTrackable() root.f = def_function.function( From b4c6669fbea0ab80abd6c54fe69eb6879a0db09b Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 09:13:37 -0700 Subject: [PATCH 43/88] Integrate LLVM at https://github.com/llvm/llvm-project/commit/24b62f28c5da PiperOrigin-RevId: 320200195 Change-Id: Ifbdb7d418c117129e54a379be3e16b522ee920fe --- tensorflow/workspace.bzl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl index e1608359623..57d5462a099 100755 --- a/tensorflow/workspace.bzl +++ b/tensorflow/workspace.bzl @@ -710,8 +710,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""): ) # Check out LLVM and MLIR from llvm-project. 
- LLVM_COMMIT = "1ea289681acf622ceda783c8fda2f16754b7c933" - LLVM_SHA256 = "c8772fc7c664cf06af2b0683e1f17555ff56132c78937c76f77d5308428957e4" + LLVM_COMMIT = "24b62f28c5daa293a2602712e1eba82cb59f3a6f" + LLVM_SHA256 = "499dd931c05f63d4a8d155d423f5a1a89ace86f4d03b22a541e29ae3e8d13b3b" LLVM_URLS = [ "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT), From 86fa7c3fc791d16e1dc62d9765be168fa6c1bf64 Mon Sep 17 00:00:00 2001 From: Scott Zhu Date: Wed, 8 Jul 2020 09:26:21 -0700 Subject: [PATCH 44/88] Fork the checkpointing_test and move the keras related test to keras/distribute. PiperOrigin-RevId: 320202511 Change-Id: I0c57c90847915bd702e4fa204653a76f7cd65c38 --- tensorflow/python/distribute/BUILD | 3 - .../python/distribute/checkpointing_test.py | 63 ------------ tensorflow/python/keras/distribute/BUILD | 16 +++ .../keras/distribute/checkpointing_test.py | 98 +++++++++++++++++++ 4 files changed, 114 insertions(+), 66 deletions(-) create mode 100644 tensorflow/python/keras/distribute/checkpointing_test.py diff --git a/tensorflow/python/distribute/BUILD b/tensorflow/python/distribute/BUILD index 2fd1fc93cd3..a083ac512af 100644 --- a/tensorflow/python/distribute/BUILD +++ b/tensorflow/python/distribute/BUILD @@ -947,10 +947,7 @@ distribute_py_test( deps = [ ":combinations", ":strategy_combinations", - ":tpu_strategy", - "//tensorflow/compiler/tests:xla_test", "//tensorflow/python/eager:test", - "//tensorflow/python/keras", "//tensorflow/python/training/tracking:util", ], ) diff --git a/tensorflow/python/distribute/checkpointing_test.py b/tensorflow/python/distribute/checkpointing_test.py index ecb678a4cf3..a4be193284e 100644 --- a/tensorflow/python/distribute/checkpointing_test.py +++ b/tensorflow/python/distribute/checkpointing_test.py @@ -23,12 +23,8 @@ from absl.testing import parameterized from tensorflow.python.distribute import combinations from tensorflow.python.distribute import strategy_combinations -from tensorflow.python.eager import backprop -from tensorflow.python.eager import def_function from tensorflow.python.eager import test -from tensorflow.python.keras.optimizer_v2 import adam from tensorflow.python.ops import array_ops -from tensorflow.python.ops import random_ops from tensorflow.python.ops import variables as variables_lib from tensorflow.python.training.tracking import util as trackable_utils @@ -71,65 +67,6 @@ class TrainingCheckpointTests(test.TestCase, parameterized.TestCase): restore_checkpoint.v = v self.assertAllClose(array_ops.zeros(variable_shape), v) - @combinations.generate( - combinations.combine( - distribution=[ - strategy_combinations.mirrored_strategy_with_one_cpu, - strategy_combinations.mirrored_strategy_with_gpu_and_cpu, - strategy_combinations.tpu_strategy, - strategy_combinations.tpu_strategy_packed_var, - strategy_combinations.central_storage_strategy_with_two_gpus, - ], - mode=["eager"])) - def testCheckpointRestoreOptimizerSlots(self, distribution): - def state(): - with distribution.scope(): - v = variables_lib.Variable(random_ops.random_normal([])) - opt = adam.Adam(0.001) - - @def_function.function - def step(): - def f(): - with backprop.GradientTape() as tape: - loss = v + v - gradients = tape.gradient(loss, [v]) - opt.apply_gradients(zip(gradients, [v])) - - distribution.run(f) - - return v, opt, step - - def checkpoint(): - v, opt, step = state() - step() - - # Save 
random weights into checkpoint. - checkpoint = trackable_utils.Checkpoint(v=v, opt=opt) - prefix = os.path.join(self.get_temp_dir(), "ckpt") - with self.test_session(): - save_path = checkpoint.save(prefix) - return save_path - - save_path = checkpoint() - - v, opt, step = state() - checkpoint = trackable_utils.Checkpoint(v=v, opt=opt) - # Restore from the checkpoint inside a distribution.scope(). - with self.test_session(): - with distribution.scope(): - checkpoint.restore(save_path) - step() - slot = opt.get_slot(v, "m") - self.assertEqual(v._distribute_strategy, slot._distribute_strategy) - - v, opt, step = state() - checkpoint = trackable_utils.Checkpoint(v=v, opt=opt) - # Restore from the checkpoint outside a distribution.scope(). - with self.test_session(): - with self.assertRaisesRegex( - ValueError, "optimizer slot variable under the scope"): - checkpoint.restore(save_path) - if __name__ == "__main__": test.main() diff --git a/tensorflow/python/keras/distribute/BUILD b/tensorflow/python/keras/distribute/BUILD index f549839f604..dee4788784a 100644 --- a/tensorflow/python/keras/distribute/BUILD +++ b/tensorflow/python/keras/distribute/BUILD @@ -77,6 +77,22 @@ cuda_py_test( ], ) +distribute_py_test( + name = "checkpointing_test", + srcs = ["checkpointing_test.py"], + main = "checkpointing_test.py", + tags = [ + "multi_and_single_gpu", + ], + deps = [ + "//tensorflow/python/distribute:combinations", + "//tensorflow/python/distribute:strategy_combinations", + "//tensorflow/python/eager:test", + "//tensorflow/python/keras/optimizer_v2", + "//tensorflow/python/training/tracking:util", + ], +) + cuda_py_test( name = "collective_all_reduce_strategy_test", srcs = ["collective_all_reduce_strategy_test.py"], diff --git a/tensorflow/python/keras/distribute/checkpointing_test.py b/tensorflow/python/keras/distribute/checkpointing_test.py new file mode 100644 index 00000000000..77c335fe46d --- /dev/null +++ b/tensorflow/python/keras/distribute/checkpointing_test.py @@ -0,0 +1,98 @@ +# Copyright 2018 The TensorFlow Authors. All Rights Reserved. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
+# ============================================================================== + +from __future__ import absolute_import +from __future__ import division +from __future__ import print_function + +import os + +from absl.testing import parameterized + +from tensorflow.python.distribute import combinations +from tensorflow.python.distribute import strategy_combinations +from tensorflow.python.eager import backprop +from tensorflow.python.eager import def_function +from tensorflow.python.eager import test +from tensorflow.python.keras.optimizer_v2 import adam +from tensorflow.python.ops import random_ops +from tensorflow.python.ops import variables as variables_lib +from tensorflow.python.training.tracking import util as trackable_utils + + +class TrainingCheckpointTests(test.TestCase, parameterized.TestCase): + + @combinations.generate( + combinations.combine( + distribution=[ + strategy_combinations.mirrored_strategy_with_one_cpu, + strategy_combinations.mirrored_strategy_with_gpu_and_cpu, + strategy_combinations.tpu_strategy, + strategy_combinations.tpu_strategy_packed_var, + strategy_combinations.central_storage_strategy_with_two_gpus, + ], + mode=["eager"])) + def testCheckpointRestoreOptimizerSlots(self, distribution): + def state(): + with distribution.scope(): + v = variables_lib.Variable(random_ops.random_normal([])) + opt = adam.Adam(0.001) + + @def_function.function + def step(): + def f(): + with backprop.GradientTape() as tape: + loss = v + v + gradients = tape.gradient(loss, [v]) + opt.apply_gradients(zip(gradients, [v])) + + distribution.run(f) + + return v, opt, step + + def checkpoint(): + v, opt, step = state() + step() + + # Save random weights into checkpoint. + checkpoint = trackable_utils.Checkpoint(v=v, opt=opt) + prefix = os.path.join(self.get_temp_dir(), "ckpt") + with self.test_session(): + save_path = checkpoint.save(prefix) + return save_path + + save_path = checkpoint() + + v, opt, step = state() + checkpoint = trackable_utils.Checkpoint(v=v, opt=opt) + # Restore from the checkpoint inside a distribution.scope(). + with self.test_session(): + with distribution.scope(): + checkpoint.restore(save_path) + step() + slot = opt.get_slot(v, "m") + self.assertEqual(v._distribute_strategy, slot._distribute_strategy) + + v, opt, step = state() + checkpoint = trackable_utils.Checkpoint(v=v, opt=opt) + # Restore from the checkpoint outside a distribution.scope(). 
+ with self.test_session(): + with self.assertRaisesRegex( + ValueError, "optimizer slot variable under the scope"): + checkpoint.restore(save_path) + + +if __name__ == "__main__": + test.main() From c0b72c2a4226119527927efa8cf9248a0cf99bb6 Mon Sep 17 00:00:00 2001 From: Dero Gharibian Date: Wed, 8 Jul 2020 09:36:57 -0700 Subject: [PATCH 45/88] Fix msan false positive for //third_party/tensorflow/go:tensorflow_test PiperOrigin-RevId: 320204496 Change-Id: I069dbb1a18eadde1f6eaa0d0d7e5c02d1799e07c --- tensorflow/go/tensor.go | 11 +++++------ 1 file changed, 5 insertions(+), 6 deletions(-) diff --git a/tensorflow/go/tensor.go b/tensorflow/go/tensor.go index 56594a73270..9221d35274c 100644 --- a/tensorflow/go/tensor.go +++ b/tensorflow/go/tensor.go @@ -21,11 +21,9 @@ package tensorflow #include #include "tensorflow/c/c_api.h" -TF_TString toNewTString(_GoString_ gstr) { - TF_TString tstr; - TF_TString_Init(&tstr); - TF_TString_Copy(&tstr, _GoStringPtr(gstr), _GoStringLen(gstr)); - return tstr; +void toNewTString(_GoString_ gstr, TF_TString *tstr) { + TF_TString_Init(tstr); + TF_TString_Copy(tstr, _GoStringPtr(gstr), _GoStringLen(gstr)); } */ import "C" @@ -448,7 +446,8 @@ func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) err } } else if v.Kind() == reflect.String { s := v.Interface().(string) - tstr := C.toNewTString(s) + var tstr C.TF_TString + C.toNewTString(s, &tstr) ptr := unsafe.Pointer(&tstr) return copyPtr(w, ptr, C.sizeof_TF_TString) } else if v.Kind() != reflect.Array { From 4e1aa305a15132e703a207c7628a7614e416dc1d Mon Sep 17 00:00:00 2001 From: Jacques Pienaar Date: Wed, 8 Jul 2020 09:45:06 -0700 Subject: [PATCH 46/88] Internal change PiperOrigin-RevId: 320205862 Change-Id: Iad999a1a85d3d193258017d8a042973d57163492 --- third_party/mlir/tblgen.bzl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/third_party/mlir/tblgen.bzl b/third_party/mlir/tblgen.bzl index 90bb2436ea8..10dde932da5 100644 --- a/third_party/mlir/tblgen.bzl +++ b/third_party/mlir/tblgen.bzl @@ -57,7 +57,8 @@ def gentbl(name, tblgen, td_file, tbl_outs, td_srcs = [], td_includes = [], td_r "$(location %s)" % td_file, "-I$(GENDIR)", ] + td_includes_cmd - rule_suffix = "_".join(opts.replace("-", "_").replace("=", "_").split(" ")) + first_opt = opts.split(" ", 1)[0] + rule_suffix = "_{}_{}".format(first_opt.replace("-", "_").replace("=", "_"), str(hash(opts))) # Rule to generate code using generated shell script. 
native.genrule( From e2f8269f17ada26a8533f716b00b9a2ec7c9c6f4 Mon Sep 17 00:00:00 2001 From: Rahul Joshi Date: Wed, 8 Jul 2020 09:46:21 -0700 Subject: [PATCH 47/88] [MLIR][NFC] Adopt variadic isa<> on MLIR Types and Attributes - Also adopt variadic llvm::isa<> in more places PiperOrigin-RevId: 320206113 Change-Id: Ia03a1503f699fb6be6dff02e90b6630d6d894b19 --- tensorflow/compiler/mlir/lite/flatbuffer_import.cc | 6 ++---- .../mlir/lite/quantization/quantization_driver.cc | 3 +-- .../mlir/lite/quantization/quantization_utils.h | 3 +-- .../compiler/mlir/lite/transforms/optimize.cc | 4 ++-- .../tensorflow/analysis/side_effect_analysis.cc | 5 ++--- .../compiler/mlir/tensorflow/ir/tf_executor.cc | 2 +- tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc | 3 +-- .../compiler/mlir/tensorflow/ir/tf_saved_model.cc | 2 +- tensorflow/compiler/mlir/tensorflow/ir/tf_types.h | 3 +-- .../tensorflow/transforms/resource_op_lifting.cc | 2 +- .../mlir/tensorflow/transforms/shape_inference.cc | 14 ++++++-------- .../transforms/stack_ops_decomposition.cc | 2 +- .../transforms/tensor_array_ops_decomposition.cc | 2 +- .../tensorflow/transforms/tpu_cluster_formation.cc | 3 +-- .../mlir/tensorflow/translate/export_graphdef.cc | 5 ++--- 15 files changed, 24 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc index faa6bc36824..cf637df3b99 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_import.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_import.cc @@ -443,8 +443,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, if (auto float_type = elem_type.dyn_cast()) { TF_ASSIGN_OR_RETURN(value, ConvertFloatBuffer(shaped_type, float_type, buffer)); - } else if (elem_type.isa() || - elem_type.isa()) { + } else if (elem_type.isa()) { TF_ASSIGN_OR_RETURN(value, ConvertIntBuffer(shaped_type, elem_type, buffer)); } else if (elem_type.isa()) { @@ -456,8 +455,7 @@ StatusOr BuildConstOp(const tflite::TensorT& tensor, refs.push_back({ref.data(), ref.size()}); value = mlir::DenseStringElementsAttr::get(shaped_type, refs); - } else if (elem_type.isa() || - elem_type.isa()) { + } else if (elem_type.isa()) { auto dialect = elem_type.getContext()->getRegisteredDialect("tf"); tensorflow::TensorProto repr = ConvertTfliteConstTensor(tensor, buffer); std::string mangled = tensorflow::mangling_util::MangleTensor(repr); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc index 1b409bc939b..0c9ccf1a979 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_driver.cc @@ -694,8 +694,7 @@ void QuantizationDriver::SetupAllStates() { fn_.walk([&](Operation *op) { if (op->isKnownTerminator() || op->hasTrait() || - llvm::isa(op) || - llvm::isa(op)) + llvm::isa(op)) return; work_list_.push_back(op); diff --git a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h index 35c930281d0..4ced43014f5 100644 --- a/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h +++ b/tensorflow/compiler/mlir/lite/quantization/quantization_utils.h @@ -386,8 +386,7 @@ struct FoldTrivalRequantizeOp : public OpRewritePattern { Operation* def = pre_quantized.getDefiningOp(); if (!def) return failure(); - if (llvm::isa(def) || - llvm::isa(def) || + if (llvm::isa(def) || def->hasTrait()) { return failure(); } 
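
The hunks in this NFC patch are mechanical rewrites made possible by the variadic form of llvm::isa<>, which folds several per-type checks into one call. As a rough, self-contained sketch of the idiom (this is not LLVM's actual implementation, only an illustration using a C++17 fold expression, with made-up kind tags standing in for MLIR Types/Attributes):

// Illustrative only: a stand-in "isa" to show why isa<A, B, C>(x) can
// replace isa<A>(x) || isa<B>(x) || isa<C>(x).  Real code uses llvm::isa.
#include <iostream>

struct Node { int kind; };                       // stand-in for an MLIR Type/Attr
template <int K> struct KindTag { static constexpr int id = K; };

template <typename T>
bool isa_single(const Node &n) { return n.kind == T::id; }

template <typename... Ts>
bool isa_any(const Node &n) { return (isa_single<Ts>(n) || ...); }  // fold over ||

int main() {
  using FloatTy = KindTag<0>;
  using IntTy = KindTag<1>;
  using StringTy = KindTag<2>;
  Node n{1};
  std::cout << std::boolalpha
            << (isa_single<FloatTy>(n) || isa_single<IntTy>(n)) << "\n"  // old spelling: true
            << isa_any<FloatTy, IntTy>(n) << "\n"                        // variadic spelling: true
            << isa_any<StringTy>(n) << "\n";                             // false
  return 0;
}

Behaviour is identical either way; the variadic spelling only shortens the call sites, which is why the change is marked NFC.
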
diff --git a/tensorflow/compiler/mlir/lite/transforms/optimize.cc b/tensorflow/compiler/mlir/lite/transforms/optimize.cc index bd3c217605b..d26a4906420 100644 --- a/tensorflow/compiler/mlir/lite/transforms/optimize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/optimize.cc @@ -560,7 +560,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { return failure(); ShapedType filter_type = filter_cst.getType(); - if (llvm::isa(binary_op) || llvm::isa(binary_op)) { + if (llvm::isa(binary_op)) { auto padding = fc_op.template getAttrOfType("padding"); if (padding && padding.getValue() != "VALID") return failure(); @@ -606,7 +606,7 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern { rewriter.create(fc_op.getLoc(), new_bias_type, new_bias); fc_op.setOperand(0, binary_op->getOperand(0)); fc_op.setOperand(2, new_bias_op); - } else if (llvm::isa(binary_op) || llvm::isa(binary_op)) { + } else if (llvm::isa(binary_op)) { // The fusion of mul/div is actually applying the following // transformation: // w * (x ' c) + b => (w ' c) x + b diff --git a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc index e4de66b59e2..be203e0397e 100644 --- a/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc +++ b/tensorflow/compiler/mlir/tensorflow/analysis/side_effect_analysis.cc @@ -168,8 +168,7 @@ void ResourceAliasAnalysis::AnalyzeFunction(FuncOp func_op) { var_handle.resource(), GetOrCreateIdForVarHandle(var_handle, &next_unique_id, &var_handle_name_id_map)); - } else if (llvm::isa(op) || - llvm::isa(op)) { + } else if (llvm::isa(op)) { for (auto operand_and_result : llvm::zip(op->getOperands(), op->getResults())) { forward_input_to_output(std::get<0>(operand_and_result), @@ -333,7 +332,7 @@ bool OpIsDeclaration(Operation* op, const ResourceAliasAnalysis& alias_analysis) { // TODO(yuanzx): Add other types of resources. return llvm::isa(op) || - ((llvm::isa(op) || llvm::isa(op)) && + (llvm::isa(op) && !FindAccessedResources(op, alias_analysis).empty()); } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc index 1e66eee06bb..1b1d5ba6f3b 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_executor.cc @@ -71,7 +71,7 @@ struct TensorFlowExecutorInlinerInterface : public DialectInlinerInterface { // Allow inlining into tf.island regions if the incoming region has a single // block. return llvm::isa(dest->getParentOp()) && - std::next(src->begin()) == src->end(); + llvm::hasSingleElement(*src); } }; diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index b2af3e45f3f..eb027748d28 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -1168,8 +1168,7 @@ void ConstOp::build(OpBuilder &builder, OperationState &result, ShapedType type; if (auto elem_attr = value.dyn_cast()) { return ConstOp::build(builder, result, elem_attr); - } else if (value.isa() || value.isa() || - value.isa()) { + } else if (value.isa()) { // All TensorFlow types must be tensor types. In the build() method, // we want to provide more flexibility by allowing attributes of scalar // types. 
But we need to wrap it up with ElementsAttr to construct diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc index f205a7ea2e8..edfc7feefd5 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.cc @@ -356,7 +356,7 @@ LogicalResult VerifyExportedFunc(FuncOp func) { LogicalResult TensorFlowSavedModelDialect::verifyOperationAttribute( Operation *op, NamedAttribute named_attr) { if (named_attr.first == "tf_saved_model.exported_names") { - if (!isa(op) && !isa(op)) { + if (!isa(op)) { return op->emitError() << "'tf_saved_model.exported_names' must be on a " "'func' or 'tf_saved_model.global_tensor' op"; } diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h index a250ca1af8c..f352bc0eb47 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_types.h @@ -90,8 +90,7 @@ class TensorFlowType : public Type { // Returns true if the specified type is a valid TensorFlow element type. static inline bool IsValidTFElementType(Type type) { - return type.isa() || type.isa() || - type.isa() || type.isa(); + return type.isa(); } // Returns true if this is a valid TensorFlow tensor type. diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc index 2d30bbd1b93..a05ebcb8191 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/resource_op_lifting.cc @@ -375,7 +375,7 @@ LogicalResult FindResourceArgUseInfo( info.data_type = assign.value().getType(); continue; } - if (isa(user) || isa(user)) { + if (isa(user)) { // Stacks will be handled by a separate pass. do_not_touch = true; break; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc index 2afc1c2d7b6..8e689f7f7b2 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/shape_inference.cc @@ -205,9 +205,9 @@ GetSubtypes(Type type) { // Returns whether type can be further refined. bool CanBeRefined(Type type) { auto shape_type = type.dyn_cast(); - return shape_type && (!shape_type.hasStaticShape() || - shape_type.getElementType().isa() || - shape_type.getElementType().isa()); + return shape_type && + (!shape_type.hasStaticShape() || + shape_type.getElementType().isa()); } // Infers the shape from a (Stateful)PartionedCall operation by looking up the @@ -712,8 +712,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { // The shape function of these ops sometimes does not propagate subtypes // (handle shapes) for resource and variant types. We use a simple passthrough // to make sure they are preserved in the output. - if (isa(op) || isa(op) || - isa(op) || isa(op)) { + if (isa(op)) { return RefineTypeForPassThroughOperands(op, op->getOperands(), op->getResults()); } @@ -729,7 +728,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { // Handle call operations by looking up callee and infering return shape as // needed. 
- if (isa(op) || isa(op)) + if (isa(op)) return InferShapeForCall(op); // tf.Cast are only inferred if they have at least one user in the TF dialect @@ -889,8 +888,7 @@ bool ShapeInference::InferShapeForSingleOperation(Operation* op) { }; auto new_element_type = shaped_type.getElementType(); // Populate the handle shapes for a resource/variant. - if (new_element_type.isa() || - new_element_type.isa()) { + if (new_element_type.isa()) { auto handle_shapes_types = c.output_handle_shapes_and_types(output); if (handle_shapes_types) { SmallVector subtypes; diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc index 734a7d04a86..5e095a311ee 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/stack_ops_decomposition.cc @@ -488,7 +488,7 @@ LogicalResult DecomposeStackOpsInternal( llvm::StringMap* decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { - if (llvm::isa(&op) || llvm::isa(&op)) { + if (llvm::isa(&op)) { // Removes identity nodes in the block. The device computation does not // need such nodes to carry information. op.replaceAllUsesWith(op.getOperands()); diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc index cbd24f8a815..9c659a95078 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tensor_array_ops_decomposition.cc @@ -809,7 +809,7 @@ LogicalResult DecomposeTensorArrayOps( llvm::StringMap* decomposed_partitioned_call_callees) { for (auto& op : llvm::make_early_inc_range(block->getOperations())) { - if (llvm::isa(&op) || llvm::isa(&op)) { + if (llvm::isa(&op)) { op.replaceAllUsesWith(op.getOperands()); op.erase(); } else if (auto ta = llvm::dyn_cast(&op)) { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc index c7557374fee..bab29404b61 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_cluster_formation.cc @@ -495,8 +495,7 @@ void TPUClusterFormation::runOnFunction() { // Remove TPUReplicatedInput and TPUReplicatedOutput nodes. auto remove_result = getFunction().walk([&](Operation* op) { - if (!llvm::isa(op) && - !llvm::isa(op)) + if (!llvm::isa(op)) return WalkResult::advance(); // Forward operand to result. When `num_replicas` attribute is 1, no diff --git a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc index b6fad8f5987..7983dfe0065 100644 --- a/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc +++ b/tensorflow/compiler/mlir/tensorflow/translate/export_graphdef.cc @@ -576,9 +576,8 @@ StatusOr> Exporter::Convert( // Adds nodes for operations. for (Operation& inst : graph_op.GetBody()) { for (auto type : inst.getResultTypes()) - if (!type.isa() && - !type.isa() && - !type.isa()) + if (!type.isa()) return errors::InvalidArgument( "Values must be of tensor type, TensorFlow control type, or " "TensorFlow token type. Found ", From 1f8b8e79ab339a3f2fc045e9d4ada8ae1c78fae1 Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 8 Jul 2020 09:48:08 -0700 Subject: [PATCH 48/88] Qualify uses of std::string PiperOrigin-RevId: 320206487 Change-Id: Id004f3aca96c4bca86326c95c95f0e928ace85ab --- tensorflow/lite/toco/toco_types.h | 6 +----- 1 file changed, 1 insertion(+), 5 deletions(-) diff --git a/tensorflow/lite/toco/toco_types.h b/tensorflow/lite/toco/toco_types.h index 76dd1b0348d..43e708ce04f 100644 --- a/tensorflow/lite/toco/toco_types.h +++ b/tensorflow/lite/toco/toco_types.h @@ -21,12 +21,8 @@ limitations under the License. #include "tensorflow/core/platform/types.h" namespace toco { -#ifdef PLATFORM_GOOGLE -using ::string; -#else -using std::string; -#endif +using std::string; using tensorflow::int16; using tensorflow::int32; using tensorflow::int64; From 1942ccc2003e41502a1a76c51b633a5d676399d6 Mon Sep 17 00:00:00 2001 From: Amit Patankar Date: Wed, 8 Jul 2020 09:53:12 -0700 Subject: [PATCH 49/88] Update syntax in pip_smoke_test. PiperOrigin-RevId: 320207489 Change-Id: Ie99eb22de56c961ebe42c6f3368542cf81e58bcd --- tensorflow/tools/pip_package/pip_smoke_test.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/tools/pip_package/pip_smoke_test.py b/tensorflow/tools/pip_package/pip_smoke_test.py index cb0a587b5eb..40d2cff56b4 100644 --- a/tensorflow/tools/pip_package/pip_smoke_test.py +++ b/tensorflow/tools/pip_package/pip_smoke_test.py @@ -32,7 +32,7 @@ PIP_PACKAGE_QUERY_EXPRESSION = ( # List of file paths containing BUILD files that should not be included for the # pip smoke test. -BUILD_BLACKLIST = [ +BUILD_DENYLIST = [ "tensorflow/lite", "tensorflow/compiler/mlir/lite", "tensorflow/python/kernel_tests/signal", @@ -46,7 +46,7 @@ def GetBuild(dir_base): items = [] for root, _, files in os.walk(dir_base): for name in files: - if (name == "BUILD" and not any(x in root for x in BUILD_BLACKLIST)): + if (name == "BUILD" and not any(x in root for x in BUILD_DENYLIST)): items.append("//" + root + ":all") return items @@ -70,9 +70,9 @@ def BuildPyTestDependencies(): PYTHON_TARGETS, PY_TEST_QUERY_EXPRESSION = BuildPyTestDependencies() -# TODO(amitpatankar): Clean up blacklist. +# TODO(amitpatankar): Clean up denylist. # List of dependencies that should not included in the pip package. -DEPENDENCY_BLACKLIST = [ +DEPENDENCY_DENYLIST = [ "//tensorflow/python:extra_py_tests_deps", "//tensorflow/cc/saved_model:saved_model_half_plus_two", "//tensorflow:no_tensorflow_py_deps", @@ -142,7 +142,7 @@ def main(): ] ignored_files_count = 0 - blacklisted_dependencies_count = len(DEPENDENCY_BLACKLIST) + denylisted_dependencies_count = len(DEPENDENCY_DENYLIST) # Compare dependencies for dependency in tf_py_test_dependencies_list: if dependency and dependency.startswith("//tensorflow"): @@ -152,14 +152,14 @@ def main(): ignore = True ignored_files_count += 1 - # Check if the dependency is in the pip package, the dependency blacklist, + # Check if the dependency is in the pip package, the dependency denylist, # or should be ignored because of its file extension. 
if not (ignore or dependency in pip_package_dependencies_list or - dependency in DEPENDENCY_BLACKLIST): + dependency in DEPENDENCY_DENYLIST): missing_dependencies.append(dependency) print("Ignored files count: %d" % ignored_files_count) - print("Blacklisted dependencies count: %d" % blacklisted_dependencies_count) + print("Denylisted dependencies count: %d" % denylisted_dependencies_count) if missing_dependencies: print("Missing the following dependencies from pip_packages:") for missing_dependency in missing_dependencies: @@ -174,7 +174,7 @@ def main(): raise RuntimeError(""" One or more added test dependencies are not in the pip package. If these test dependencies need to be in TensorFlow pip package, please add them to //tensorflow/tools/pip_package/BUILD. -Else either blacklist the dependencies in //tensorflow/tools/pip_package/pip_smoke_test.py +Else either denylist the dependencies in //tensorflow/tools/pip_package/pip_smoke_test.py or add no_pip tag to the test.""") else: From b1cdfe3b9bb526477f2ae19b1497809dc6be508d Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 09:54:56 -0700 Subject: [PATCH 50/88] Add SparseTensor support to assert_type. PiperOrigin-RevId: 320207860 Change-Id: I869efd91e2f5911e6f3bc184f06685245b70f1f0 --- .../python/kernel_tests/check_ops_test.py | 23 +++++++++++++++++++ tensorflow/python/ops/check_ops.py | 12 ++++++---- 2 files changed, 30 insertions(+), 5 deletions(-) diff --git a/tensorflow/python/kernel_tests/check_ops_test.py b/tensorflow/python/kernel_tests/check_ops_test.py index 376b0058927..acc5af03097 100644 --- a/tensorflow/python/kernel_tests/check_ops_test.py +++ b/tensorflow/python/kernel_tests/check_ops_test.py @@ -1528,12 +1528,35 @@ class AssertTypeTest(test.TestCase): out = array_ops.identity(integers) self.evaluate(out) + @test_util.run_in_graph_and_eager_modes + def test_sparsetensor_doesnt_raise_when_correct_type(self): + sparse_float = sparse_tensor.SparseTensor( + constant_op.constant([[111], [232]], dtypes.int64), + constant_op.constant([23.4, -43.2], dtypes.float32), + constant_op.constant([500], dtypes.int64)) + + with ops.control_dependencies( + [check_ops.assert_type(sparse_float, dtypes.float32)]): + out = sparse_tensor.SparseTensor(sparse_float.indices, + array_ops.identity(sparse_float.values), + sparse_float.dense_shape) + self.evaluate(out) + @test_util.run_in_graph_and_eager_modes def test_raises_when_wrong_type(self): floats = constant_op.constant([1.0, 2.0], dtype=dtypes.float16) with self.assertRaisesRegex(TypeError, "must be of type.*float32"): check_ops.assert_type(floats, dtypes.float32) + @test_util.run_in_graph_and_eager_modes + def test_sparsetensor_raises_when_wrong_type(self): + sparse_float16 = sparse_tensor.SparseTensor( + constant_op.constant([[111], [232]], dtypes.int64), + constant_op.constant([23.4, -43.2], dtypes.float16), + constant_op.constant([500], dtypes.int64)) + with self.assertRaisesRegexp(TypeError, "must be of type.*float32"): + check_ops.assert_type(sparse_float16, dtypes.float32) + class AssertShapesTest(test.TestCase): diff --git a/tensorflow/python/ops/check_ops.py b/tensorflow/python/ops/check_ops.py index a9c8ca37724..9bc638ac5a2 100644 --- a/tensorflow/python/ops/check_ops.py +++ b/tensorflow/python/ops/check_ops.py @@ -1518,7 +1518,7 @@ def assert_type_v2(tensor, tf_type, message=None, name=None): This can always be checked statically, so this method returns nothing. Args: - tensor: A `Tensor`. + tensor: A `Tensor` or `SparseTensor`. 
tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, etc). message: A string to prefix to the default message. @@ -1537,7 +1537,7 @@ def assert_type(tensor, tf_type, message=None, name=None): """Statically asserts that the given `Tensor` is of the specified type. Args: - tensor: A `Tensor`. + tensor: A `Tensor` or `SparseTensor`. tf_type: A tensorflow type (`dtypes.float32`, `tf.int64`, `dtypes.bool`, etc). message: A string to prefix to the default message. @@ -1551,13 +1551,15 @@ def assert_type(tensor, tf_type, message=None, name=None): """ message = message or '' with ops.name_scope(name, 'assert_type', [tensor]): - tensor = ops.convert_to_tensor(tensor, name='tensor') + if not isinstance(tensor, sparse_tensor.SparseTensor): + tensor = ops.convert_to_tensor(tensor, name='tensor') if tensor.dtype != tf_type: if context.executing_eagerly(): raise TypeError('%s tensor must be of type %s' % (message, tf_type)) else: - raise TypeError('%s %s must be of type %s' % (message, tensor.name, - tf_type)) + raise TypeError( + '%s %s must be of type %s' % + (message, tensor.name if hasattr(tensor, 'name') else '', tf_type)) return control_flow_ops.no_op('statically_determined_correct_type') From a85beef408a3f5dcd93158e88a0aade6cb70f8c5 Mon Sep 17 00:00:00 2001 From: Kuangyuan Chen Date: Wed, 8 Jul 2020 09:56:54 -0700 Subject: [PATCH 51/88] Fix a race condition in attr_builder where the AttrTypeMap for the same op_name might be created multiple times, causing memory leak. PiperOrigin-RevId: 320208354 Change-Id: I36b7f077ae6645160198c55d0441d482141aa707 --- tensorflow/core/common_runtime/eager/attr_builder.cc | 12 +++++++++++- 1 file changed, 11 insertions(+), 1 deletion(-) diff --git a/tensorflow/core/common_runtime/eager/attr_builder.cc b/tensorflow/core/common_runtime/eager/attr_builder.cc index 2d85b2f764d..fd79a82e4b2 100644 --- a/tensorflow/core/common_runtime/eager/attr_builder.cc +++ b/tensorflow/core/common_runtime/eager/attr_builder.cc @@ -71,7 +71,15 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); if (*out != nullptr) return Status::OK(); } + mutex_lock l(g_op_name_to_attr_type_map_lock); + + // Check the existence of AttrTypeMap for op_name again because another thread + // may insert this map after the tf_shared_lock is released but before the + // mutex_lock is acquired. + *out = gtl::FindPtrOrNull(*OpNameToAttrTypeMap(), op_name); + if (*out != nullptr) return Status::OK(); + const OpDef* op_def = nullptr; Status s = OpDefForOp(op_name, &op_def); if (errors::IsNotFound(s)) { @@ -121,7 +129,9 @@ Status AttrTypeMapForOp(const char* op_name, const AttrTypeMap** out, gtl::InsertIfNotPresent(m.get(), attr.name(), t); } *out = m.get(); - (*OpNameToAttrTypeMap())[op_name] = m.release(); + auto r = OpNameToAttrTypeMap()->emplace(op_name, m.release()); + DCHECK(r.second) << "AttrTypeMap already exists for " << op_name; + return Status::OK(); } From 7eab1f3bfe796f453b4549a0923ef954c822671c Mon Sep 17 00:00:00 2001 From: Karmel Allison Date: Wed, 8 Jul 2020 10:04:34 -0700 Subject: [PATCH 52/88] Replace instances of "whitelist" with "allowlist" where possible. See Google Developer guidelines at https://developers.google.com/style/word-list#blacklist for more information. 
PiperOrigin-RevId: 320210110 Change-Id: I480d2b1c80d7d77fdd071b7642011758988f18c0 --- RELEASE.md | 3 + SECURITY.md | 2 +- .../compiler/jit/mark_for_compilation_pass.cc | 40 +-- .../compiler/jit/mark_for_compilation_pass.h | 6 +- .../jit/mark_for_compilation_pass_test.cc | 18 +- .../mlir/lite/common/tfl_pass_config.h | 6 +- .../compiler/mlir/lite/flatbuffer_export.cc | 12 +- .../mlir/lite/tests/prepare-quantize.mlir | 2 +- .../mlir/lite/tests/trim-functions-tf.mlir | 2 +- .../compiler/mlir/lite/transforms/passes.h | 2 +- .../mlir/lite/transforms/prepare_quantize.cc | 10 +- .../mlir/lite/transforms/trim_functions_tf.cc | 46 ++-- .../xla/tests/legalize-tf-with-tf2xla.mlir | 4 +- .../xla/transforms/legalize_tf_with_tf2xla.cc | 6 +- tensorflow/compiler/tf2xla/xla_op_registry.cc | 58 ++--- tensorflow/compiler/tf2xla/xla_op_registry.h | 12 +- .../common_runtime/graph_execution_state.cc | 4 +- .../framework/dataset_stateful_op_whitelist.h | 28 +-- .../core/grappler/costs/graph_properties.cc | 8 +- .../grappler/costs/graph_properties_test.cc | 2 +- .../optimizers/auto_mixed_precision.cc | 224 ++++++++--------- .../optimizers/auto_mixed_precision_lists.h | 22 +- .../optimizers/auto_mixed_precision_test.cc | 227 +++++++++--------- .../grappler/optimizers/constant_folding.cc | 12 +- .../grappler/optimizers/constant_folding.h | 2 +- .../core/kernels/data/captured_function.cc | 12 +- tensorflow/go/op/wrappers.go | 6 +- tensorflow/lite/delegates/flex/BUILD | 4 +- ...ed_flex_ops.cc => allowlisted_flex_ops.cc} | 12 +- ...sted_flex_ops.h => allowlisted_flex_ops.h} | 16 +- tensorflow/lite/delegates/hexagon/utils.cc | 2 +- .../delegates/nnapi/acceleration_test_list.cc | 4 +- .../delegates/nnapi/acceleration_test_util.h | 2 +- .../lite/delegates/nnapi/nnapi_delegate.cc | 2 +- .../lite/experimental/acceleration/README.md | 4 +- .../configuration/configuration.proto | 2 +- tensorflow/lite/g3doc/guide/ops_select.md | 4 +- .../lite/gpu/CompatibilityListTest.java | 4 +- .../lite/kernels/acceleration_test_util.h | 2 +- .../acceleration_test_util_internal_test.cc | 42 ++-- .../micro/tools/make/generate_keil_project.py | 2 +- tensorflow/lite/toco/tflite/export_test.cc | 2 +- tensorflow/lite/toco/tflite/operator.cc | 8 +- .../tasks/coco_object_detection/README.md | 2 +- .../preprocess_coco_minival.py | 38 +-- .../lite/tools/optimize/quantize_model.h | 2 +- tensorflow/python/__init__.py | 2 +- .../python/autograph/converters/call_trees.py | 2 +- .../autograph/core/converter_testing.py | 12 +- .../autograph/g3doc/reference/functions.md | 20 +- tensorflow/python/autograph/impl/api.py | 20 +- tensorflow/python/autograph/impl/api_test.py | 36 +-- .../python/autograph/impl/conversion.py | 70 +++--- .../python/autograph/impl/conversion_test.py | 53 ++-- .../pyct/common_transformers/anf_test.py | 6 +- .../python/autograph/pyct/error_utils.py | 17 +- tensorflow/python/data/ops/dataset_ops.py | 4 +- tensorflow/python/debug/cli/analyzer_cli.py | 4 +- .../python/debug/cli/analyzer_cli_test.py | 6 +- .../debug/lib/check_numerics_callback.py | 2 +- tensorflow/python/debug/lib/debug_utils.py | 40 +-- .../python/debug/lib/debug_utils_test.py | 24 +- .../debug/lib/dist_session_debug_grpc_test.py | 8 +- .../python/debug/lib/grpc_large_data_test.py | 23 +- .../debug/lib/session_debug_grpc_test.py | 4 +- tensorflow/python/debug/lib/source_utils.py | 17 +- .../python/debug/lib/source_utils_test.py | 6 +- .../debug/wrappers/dumping_wrapper_test.py | 15 +- tensorflow/python/debug/wrappers/framework.py | 97 ++++---- 
tensorflow/python/debug/wrappers/hooks.py | 24 +- .../debug/wrappers/local_cli_wrapper.py | 6 +- tensorflow/python/distribute/mirrored_run.py | 2 +- tensorflow/python/eager/function.py | 6 +- .../python/framework/auto_control_deps.py | 6 +- .../python/framework/convert_to_constants.py | 24 +- tensorflow/python/framework/func_graph.py | 18 +- tensorflow/python/framework/function.py | 24 +- tensorflow/python/framework/function_test.py | 6 +- .../python/framework/graph_util_impl.py | 2 +- tensorflow/python/framework/importer_test.py | 2 +- tensorflow/python/framework/python_op_gen.cc | 2 +- .../python/framework/python_op_gen_main.cc | 12 +- .../grappler/auto_mixed_precision_test.py | 4 +- .../python/keras/engine/training_utils.py | 8 +- tensorflow/python/keras/engine/training_v1.py | 2 +- .../python/keras/layers/wrappers_test.py | 2 +- .../keras/optimizer_v2/optimizer_v2_test.py | 6 +- .../keras/preprocessing/dataset_utils.py | 4 +- .../keras/preprocessing/image_dataset.py | 4 +- tensorflow/python/ops/while_v2.py | 14 +- .../ops/while_v2_indexed_slices_rewriter.py | 8 +- .../selective_registration_header_lib.py | 4 +- tensorflow/python/tpu/feature_column_test.py | 2 +- .../training/experimental/mixed_precision.py | 4 +- tensorflow/python/util/all_util.py | 2 +- tensorflow/tools/ci_build/ci_sanity.sh | 22 +- tensorflow/tools/test/check_futures_test.py | 6 +- 97 files changed, 831 insertions(+), 813 deletions(-) rename tensorflow/lite/delegates/flex/{whitelisted_flex_ops.cc => allowlisted_flex_ops.cc} (98%) rename tensorflow/lite/delegates/flex/{whitelisted_flex_ops.h => allowlisted_flex_ops.h} (65%) diff --git a/RELEASE.md b/RELEASE.md index 5c05f2a4285..22edbcd1f41 100644 --- a/RELEASE.md +++ b/RELEASE.md @@ -50,6 +50,9 @@ * Tracing and Debugging: * * Other: + * We have replaced uses of "whitelist" with "allowlist" where possible. + Please see https://developers.google.com/style/word-list#blacklist for more + context. * ## Thanks to our Contributors diff --git a/SECURITY.md b/SECURITY.md index f3a6c148b2e..6c722766b3a 100644 --- a/SECURITY.md +++ b/SECURITY.md @@ -44,7 +44,7 @@ Even if the untrusted party only supplies the serialized computation graph (in form of a `GraphDef`, `SavedModel`, or equivalent on-disk format), the set of computation primitives available to TensorFlow is powerful enough that you should assume that the TensorFlow process effectively executes arbitrary -code. One common solution is to whitelist only a few safe Ops. While this is +code. One common solution is to allow only a few safe Ops. While this is possible in theory, we still recommend you sandbox the execution. It depends on the computation graph whether a user provided checkpoint is safe. 
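
Most of this patch is a mechanical rename, but the clustering allowlist it touches is worth a concrete illustration. The --tf_xla_ops_to_cluster flag handled by GetOrCreateAllowlist() in the hunk below is a comma-separated list in which a bare op name is taken literally and a category name such as FUSIBLE expands through a table of op sets. A simplified, self-contained sketch of that expansion follows; it uses plain STL containers and an invented category table rather than the real TensorFlow flag machinery:

// Toy sketch of expanding a comma-separated op/category list into an op
// allowlist, mirroring the shape of GetOrCreateAllowlist() below.
// The category names and op sets here are made up for illustration.
#include <iostream>
#include <map>
#include <set>
#include <sstream>
#include <string>

std::set<std::string> ExpandAllowlist(const std::string& flag) {
  const std::map<std::string, std::set<std::string>> kCategories = {
      {"PW", {"Add", "Sub", "Relu"}},        // toy "pointwise" category
      {"REDUCTION", {"Sum", "Max"}},         // toy "reduction" category
  };
  std::set<std::string> allowlist;
  std::stringstream ss(flag);
  std::string item;
  while (std::getline(ss, item, ',')) {
    if (item.empty()) continue;
    if (item == "FUSIBLE") {                 // expand every category
      for (const auto& kv : kCategories)
        allowlist.insert(kv.second.begin(), kv.second.end());
    } else if (kCategories.count(item)) {    // expand one named category
      const auto& ops = kCategories.at(item);
      allowlist.insert(ops.begin(), ops.end());
    } else {
      allowlist.insert(item);                // a single user-provided op
    }
  }
  return allowlist;
}

int main() {
  for (const auto& op : ExpandAllowlist("PW,MatMul"))
    std::cout << op << "\n";   // Add, MatMul, Relu, Sub
}
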
diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.cc b/tensorflow/compiler/jit/mark_for_compilation_pass.cc index dc5df94e963..55ff57a04c5 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.cc @@ -1096,33 +1096,33 @@ StatusOr IsIdentityDrivingConstsInLoop(Node* node) { return true; } -absl::flat_hash_set GetOrCreateWhitelist() { - absl::flat_hash_map>* whitelist_table = - tensorflow::GetWhitelistTable(); +absl::flat_hash_set GetOrCreateAllowlist() { + absl::flat_hash_map>* allowlist_table = + tensorflow::GetAllowlistTable(); MarkForCompilationPassFlags* flags = GetMarkForCompilationPassFlags(); - absl::flat_hash_set whitelist; + absl::flat_hash_set allowlist; for (auto s : absl::StrSplit(flags->tf_xla_ops_to_cluster, ',')) { if (s == "FUSIBLE") { - for (auto pair : *whitelist_table) { - whitelist.insert(pair.second.begin(), pair.second.end()); + for (auto pair : *allowlist_table) { + allowlist.insert(pair.second.begin(), pair.second.end()); } - } else if (whitelist_table->contains(s)) { - auto v = whitelist_table->at(s); - whitelist.insert(v.begin(), v.end()); + } else if (allowlist_table->contains(s)) { + auto v = allowlist_table->at(s); + allowlist.insert(v.begin(), v.end()); } else if (!s.empty()) { // Should be a user provided TF operation. - whitelist.insert(string(s)); + allowlist.insert(string(s)); } } - if (VLOG_IS_ON(2) && !whitelist.empty()) { - std::vector vwhitelist(whitelist.begin(), whitelist.end()); - absl::c_sort(vwhitelist); + if (VLOG_IS_ON(2) && !allowlist.empty()) { + std::vector vallowlist(allowlist.begin(), allowlist.end()); + absl::c_sort(vallowlist); VLOG(2) << "XLA clustering will only consider the following TF operations: " - << absl::StrJoin(vwhitelist, " "); + << absl::StrJoin(vallowlist, " "); } - return whitelist; + return allowlist; } Status MarkForCompilationPassImpl::FindCompilationCandidates() { @@ -1156,12 +1156,12 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { VLOG(2) << "sorted_nodes.size() = " << sorted_nodes.size(); - auto whitelist = GetOrCreateWhitelist(); + auto allowlist = GetOrCreateAllowlist(); std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); // Check that user's provided TF operation really exists. 
- for (const auto& s : whitelist) { + for (const auto& s : allowlist) { if (!all_ops.contains(string(s))) { return errors::InvalidArgument( "The operation '", s, @@ -1206,7 +1206,7 @@ Status MarkForCompilationPassImpl::FindCompilationCandidates() { continue; } - if (!whitelist.empty() && !whitelist.contains(node->def().op())) { + if (!allowlist.empty() && !allowlist.contains(node->def().op())) { VLOG(1) << "Rejecting TF operation " << node->def().op() << " as it is not listed in --tf_xla_ops_to_cluster."; continue; @@ -1781,7 +1781,7 @@ Status MarkForCompilationPass::RunForTest( return MarkForCompilation(options, debug_options); } -absl::flat_hash_map>* GetWhitelistTable() { +absl::flat_hash_map>* GetAllowlistTable() { // Table format: category name: {list of TF operations in that category} static absl::flat_hash_map>* result = new absl::flat_hash_map>{ @@ -1845,7 +1845,7 @@ absl::flat_hash_map>* GetWhitelistTable() { namespace testing { void ResetClusterSequenceNumber() { cluster_sequence_num = 0; } -absl::flat_hash_set GetKnownXLAWhitelistOp() { +absl::flat_hash_set GetKnownXLAAllowlistOp() { absl::flat_hash_set result{"AdjustContrastv2", "AdjustHue", "AdjustSaturation", diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass.h b/tensorflow/compiler/jit/mark_for_compilation_pass.h index 8b660710898..0e9a64e7f28 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass.h +++ b/tensorflow/compiler/jit/mark_for_compilation_pass.h @@ -58,7 +58,7 @@ bool IsCompilable(FunctionLibraryRuntime* flr, const NodeDef& ndef, RecursiveCompilabilityChecker::UncompilableNodesMap* uncompilable_node_info = nullptr); -absl::flat_hash_map>* GetWhitelistTable(); +absl::flat_hash_map>* GetAllowlistTable(); namespace testing { // DO NOT USE IN PRODUCTION. @@ -66,8 +66,8 @@ namespace testing { // Resets some internal state to let us write reliable unit tests. void ResetClusterSequenceNumber(); -// Return a list of operation that we choose not to put into the whitelist. -absl::flat_hash_set GetKnownXLAWhitelistOp(); +// Return a list of operation that we choose not to put into the allowlist. 
+absl::flat_hash_set GetKnownXLAAllowlistOp(); } // namespace testing } // namespace tensorflow diff --git a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc index 0e1cc2d19fe..3ae72eb514c 100644 --- a/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc +++ b/tensorflow/compiler/jit/mark_for_compilation_pass_test.cc @@ -1802,34 +1802,34 @@ TEST(XlaCompilationTest, StagePipelinePreservedByClusterScopingPass) { EXPECT_NE(clusters["relu0"], clusters["relu1"]); } } -TEST(XlaCompilationTest, XLALiteWhitelist) { - auto* whitelist_table = tensorflow::GetWhitelistTable(); - absl::flat_hash_set hwhitelist; +TEST(XlaCompilationTest, XLALiteAllowlist) { + auto* allowlist_table = tensorflow::GetAllowlistTable(); + absl::flat_hash_set hallowlist; std::vector vall_ops = XlaOpRegistry::GetAllRegisteredOps(); absl::flat_hash_set all_ops(vall_ops.begin(), vall_ops.end()); // Check that all the operations in the table are existing TF operations - for (auto pair : *whitelist_table) { - hwhitelist.insert(pair.second.begin(), pair.second.end()); + for (auto pair : *allowlist_table) { + hallowlist.insert(pair.second.begin(), pair.second.end()); for (auto op : pair.second) { ASSERT_TRUE(all_ops.contains(op)); } } - // Check that all registered XLA operation are in the whitelist + // Check that all registered XLA operation are in the allowlist // table or are known to not be in it. absl::flat_hash_set known_not_in_list = - tensorflow::testing::GetKnownXLAWhitelistOp(); + tensorflow::testing::GetKnownXLAAllowlistOp(); std::vector unknow_op; for (string op : vall_ops) { - if (!hwhitelist.contains(op) && !known_not_in_list.contains(op)) { + if (!hallowlist.contains(op) && !known_not_in_list.contains(op)) { unknow_op.push_back(op); } } EXPECT_TRUE(unknow_op.empty()) << "Someone added support for a new TF opeations inside XLA. They must " - "be included in the XLALite whitelist or blacklist:\n" + "be included in the XLALite allowlist or blacklist:\n" << absl::StrJoin(unknow_op, "\n"); } } // namespace diff --git a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h index 83ff9971246..92c45b98ea7 100644 --- a/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h +++ b/tensorflow/compiler/mlir/lite/common/tfl_pass_config.h @@ -30,7 +30,7 @@ struct PassConfig { explicit PassConfig(QuantizationSpecs specs) : emit_builtin_tflite_ops(true), lower_tensor_list_ops(false), - trim_functions_whitelist({}), + trim_functions_allowlist({}), quant_specs(std::move(specs)), form_clusters(false), unfold_batch_matmul(true), @@ -44,8 +44,8 @@ struct PassConfig { // If `lower_tensor_list_ops` is true, tensorlist ops will be lowered to basic // TF ops before legalization to TF Lite dialect. bool lower_tensor_list_ops; - // The whitelist of functions that would be preserved after trimming. - llvm::ArrayRef trim_functions_whitelist; + // The allowlist of functions that would be preserved after trimming. + llvm::ArrayRef trim_functions_allowlist; // All information about quantization. QuantizationSpecs quant_specs; // If `form_clusters` is true , clusters are formed by grouping consecutive diff --git a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc index ee8b34598e2..fb20e842a75 100644 --- a/tensorflow/compiler/mlir/lite/flatbuffer_export.cc +++ b/tensorflow/compiler/mlir/lite/flatbuffer_export.cc @@ -71,7 +71,7 @@ limitations under the License. 
#include "tensorflow/core/platform/errors.h" #include "tensorflow/core/platform/logging.h" #include "tensorflow/core/platform/status.h" -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" +#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/kernels/internal/kernel_utils.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/string_util.h" @@ -101,7 +101,7 @@ using mlir::Value; using tensorflow::OpOrArgLocNameMapper; using tensorflow::OpOrArgNameMapper; using tensorflow::Status; -using tflite::flex::IsWhitelistedFlexOp; +using tflite::flex::IsAllowlistedFlexOp; using xla::StatusOr; template @@ -972,7 +972,7 @@ Optional> Translator::BuildOperator( // model is of an open op system. // // The following algorithm is followed: - // if flex is enabled and the op is whitelisted as flex + // if flex is enabled and the op is allowlisted as flex // we emit op as flex. // if custom is enabled // we emit the op as custom. @@ -982,11 +982,11 @@ Optional> Translator::BuildOperator( } // Flex op case - // Eventually, the whitelist will go away and we will rely on some TF op + // Eventually, the allowlist will go away and we will rely on some TF op // trait (e.g. No side effect) to determine if it is a supported "Flex" // op or not. if (enabled_op_types_.contains(OpType::kSelectTf) && - IsWhitelistedFlexOp(node_def->op())) { + IsAllowlistedFlexOp(node_def->op())) { // Construct ops as flex op encoding TensorFlow node definition // as custom options. // Flex ops are named with the kFlexOpNamePrefix prefix to the actual @@ -1037,7 +1037,7 @@ Optional> Translator::BuildOperator( } // Insert failed op to `flex_ops` or `custom_ops`. - if (IsWhitelistedFlexOp(node_def->op())) { + if (IsAllowlistedFlexOp(node_def->op())) { failed_flex_ops_.insert(os.str()); } else { failed_custom_ops_.insert(os.str()); diff --git a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir index 05eb8de71e9..53caf15bc8f 100644 --- a/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir +++ b/tensorflow/compiler/mlir/lite/tests/prepare-quantize.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-whitelist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s +// RUN: tf-opt %s -tfl-prepare-quantize -tfl-test-quantize-allowlist="quantize_float_placeholder_only,not_reset_input" | FileCheck %s // CHECK-LABEL: quantize_float_placeholder_only func @quantize_float_placeholder_only(%arg0: tensor, %arg1: tensor<2x3xi32>, %arg2: tensor<2x3xf32>) -> (tensor, tensor<2x3xi32>, tensor<2x3xf32>) { diff --git a/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir index 0087ae12156..0b8c147cde2 100644 --- a/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir +++ b/tensorflow/compiler/mlir/lite/tests/trim-functions-tf.mlir @@ -1,4 +1,4 @@ -// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-whitelist="bar,foobar" %s | FileCheck %s +// RUN: tf-opt -tfl-trim-funcs-tf -tfl-trim-funcs-allowlist="bar,foobar" %s | FileCheck %s func @foo(%arg0: tensor<1x4xf32>, %arg1: tensor<1x4xf32>) -> tensor<1x4xf32> { return %arg0 : tensor<1x4xf32> diff --git a/tensorflow/compiler/mlir/lite/transforms/passes.h b/tensorflow/compiler/mlir/lite/transforms/passes.h index 105c9394fb4..af97931b2a3 100644 --- a/tensorflow/compiler/mlir/lite/transforms/passes.h +++ b/tensorflow/compiler/mlir/lite/transforms/passes.h @@ -61,7 
+61,7 @@ std::unique_ptr> CreatePostQuantizePass( // Creates an instance of the TensorFlow Lite dialect TrimFunctions // pass. std::unique_ptr> CreateTrimFunctionsPass( - llvm::ArrayRef trim_funcs_whitelist); + llvm::ArrayRef trim_funcs_allowlist); // Creates an instance of the TensorFlow Lite dialect PrepareCompositeFunctions // pass. diff --git a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc index 579063f9c9d..9a27d0de62a 100644 --- a/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc +++ b/tensorflow/compiler/mlir/lite/transforms/prepare_quantize.cc @@ -35,9 +35,9 @@ limitations under the License. #include "tensorflow/core/framework/types.pb.h" // NOLINTNEXTLINE -static llvm::cl::list quantize_whitelist( - "tfl-test-quantize-whitelist", llvm::cl::value_desc("list"), - llvm::cl::desc("comma separated list of whitelisted functions to be " +static llvm::cl::list quantize_allowlist( + "tfl-test-quantize-allowlist", llvm::cl::value_desc("list"), + llvm::cl::desc("comma separated list of allowlisted functions to be " "quantized. Only used in tests"), llvm::cl::CommaSeparated); @@ -108,7 +108,7 @@ class PrepareQuantizePass // Get the min and max values from the quantization specification for the // current function function and argument index. Uses default values if - // the function is specified in the `quantize_whitelist`. + // the function is specified in the `quantize_allowlist`. std::pair, llvm::Optional> GetMinMaxValuesForArgument(llvm::StringRef func_name, int index) { if (func_name == quant_specs_.target_func) { @@ -132,7 +132,7 @@ bool PrepareQuantizePass::SetInputNodesQuantizationParams(FuncOp func) { // Skip this function because it isn't the target function from the spec or // in the function while list. if (target_func != func_name && - !llvm::is_contained(quantize_whitelist, func_name)) { + !llvm::is_contained(quantize_allowlist, func_name)) { return false; } diff --git a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc index 013ffc26ea8..9eedf2b4fa6 100644 --- a/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc +++ b/tensorflow/compiler/mlir/lite/transforms/trim_functions_tf.cc @@ -29,12 +29,12 @@ limitations under the License. #include "mlir/Pass/Pass.h" // from @llvm-project #include "tensorflow/compiler/mlir/lite/transforms/passes.h" -// The cmd line flag to specify the whitelist of functions. Rest are trimmed +// The cmd line flag to specify the allowlist of functions. Rest are trimmed // after this pass is run. // NOLINTNEXTLINE -static llvm::cl::list trim_funcs_whitelist( - "tfl-trim-funcs-whitelist", llvm::cl::value_desc("list"), - llvm::cl::desc("comma separated list of whitelisted functions. The first " +static llvm::cl::list trim_funcs_allowlist( + "tfl-trim-funcs-allowlist", llvm::cl::value_desc("list"), + llvm::cl::desc("comma separated list of allowlisted functions. The first " "function specified will be used as main."), llvm::cl::CommaSeparated); @@ -43,25 +43,25 @@ namespace TFL { namespace { // The pass to trim functions before we legalize to TFL -// dialect using the specified whitelist. +// dialect using the specified allowlist. 
class TrimFunctionsPass : public mlir::PassWrapper> { public: - explicit TrimFunctionsPass() : trim_funcs_whitelist_(trim_funcs_whitelist) {} - explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_whitelist) - : trim_funcs_whitelist_(trim_funcs_whitelist) {} + explicit TrimFunctionsPass() : trim_funcs_allowlist_(trim_funcs_allowlist) {} + explicit TrimFunctionsPass(llvm::ArrayRef trim_funcs_allowlist) + : trim_funcs_allowlist_(trim_funcs_allowlist) {} private: void runOnOperation() override; bool TrimModule(); void Verify(); - llvm::ArrayRef trim_funcs_whitelist_; + llvm::ArrayRef trim_funcs_allowlist_; }; void TrimFunctionsPass::runOnOperation() { - // trim the functions in the module using the trim_funcs_whitelist_ - // by removing functions not in the whitelist. + // trim the functions in the module using the trim_funcs_allowlist_ + // by removing functions not in the allowlist. if (TrimModule()) { // verify the updated module is still valid, if not signal the // pass as failed. @@ -70,20 +70,20 @@ void TrimFunctionsPass::runOnOperation() { } bool TrimFunctionsPass::TrimModule() { - // if no trim_funcs_whitelist_ is specified, this pass is a no-op. - if (trim_funcs_whitelist_.empty()) return false; + // if no trim_funcs_allowlist_ is specified, this pass is a no-op. + if (trim_funcs_allowlist_.empty()) return false; llvm::SmallVector funcs_to_trim; for (auto func : getOperation().getOps()) { - if (llvm::is_contained(trim_funcs_whitelist_, func.getName())) { - // If no main is specified in the whitelist, use the 1st func - // in trim_funcs_whitelist as the main. + if (llvm::is_contained(trim_funcs_allowlist_, func.getName())) { + // If no main is specified in the allowlist, use the 1st func + // in trim_funcs_allowlist as the main. // TODO(ashwinm): Currently tflite flatbuffer export assumes there is // always a main. This is strictly not required for TFlite. We need to // remove that restriction once we have support to attribute the main // tensorflow function in MLIR TF import using an entry_point attr. - if (!llvm::is_contained(trim_funcs_whitelist_, "main") && - func.getName() == trim_funcs_whitelist_[0]) { + if (!llvm::is_contained(trim_funcs_allowlist_, "main") && + func.getName() == trim_funcs_allowlist_[0]) { func.setName("main"); } } else { @@ -99,7 +99,7 @@ bool TrimFunctionsPass::TrimModule() { } // validate that all reachable functions from the remaining functions are -// also in the whitelist. +// also in the allowlist. void TrimFunctionsPass::Verify() { // TODO(ashwinm): Instead, we should make sure that references to all // SymbolRefAttrs of all ops are present. @@ -109,7 +109,7 @@ void TrimFunctionsPass::Verify() { auto walk_result = func.walk([&](CallOp op) -> WalkResult { if (!symbol_table.lookup(op.getCallee())) return getOperation().emitError() - << func.getName() << " is not in the funcs whitelist"; + << func.getName() << " is not in the funcs allowlist"; return WalkResult::advance(); }); if (walk_result.wasInterrupted()) return signalPassFailure(); @@ -121,13 +121,13 @@ void TrimFunctionsPass::Verify() { // Creates an instance of the TensorFlow Lite dialect TrimFunctions /// pass. 
std::unique_ptr> CreateTrimFunctionsPass( - llvm::ArrayRef trim_funcs_whitelist) { - return std::make_unique(trim_funcs_whitelist); + llvm::ArrayRef trim_funcs_allowlist) { + return std::make_unique(trim_funcs_allowlist); } static PassRegistration pass( "tfl-trim-funcs-tf", - "Trim functions to restrict them to a specified whitelist prior to " + "Trim functions to restrict them to a specified allowlist prior to " "legalization to TensorFlow lite dialect"); } // namespace TFL diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir index 95460dcf998..53747128af0 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf-with-tf2xla.mlir @@ -23,8 +23,8 @@ func @unknown_op(%arg0: tensor<2xf32>) -> tensor<2xf32> { return %0 : tensor<2xf32> } -// CHECK-LABEL: not_whitelisted_op -func @not_whitelisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { +// CHECK-LABEL: not_allowlisted_op +func @not_allowlisted_op(%arg0: tensor<3xi32>, %arg1: tensor, %arg2: tensor) -> tensor { // CHECK: tf.TensorListReserve %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<3xi32>, tensor) -> tensor>> // CHECK: tf.TensorListGetItem diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 8090c0bd097..916e7db33e3 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -75,10 +75,10 @@ namespace { template using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok -static bool IsOpWhitelisted(Operation* op) { +static bool IsOpAllowlisted(Operation* op) { // White-listed TensorFlow ops are known to have well behaved tf2xla kernels // building valid MLIR using MlirHloBuilder. - // TODO(hinsu): Drop explicit whitelist when MLIR based bridge is enabled for + // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for // all tf2xla kernels. // clang-format off static llvm::SmallDenseSet ops = { @@ -342,7 +342,7 @@ LogicalResult FuncLegalizer::Legalize() { } LogicalResult FuncLegalizer::LegalizeOp(Operation* op) { - if (!IsOpWhitelisted(op)) return success(); + if (!IsOpAllowlisted(op)) return success(); // Only static shaped operands are supported in XLA builders for now. for (Type ty : op->getOperandTypes()) { diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.cc b/tensorflow/compiler/tf2xla/xla_op_registry.cc index a43608bd434..e37f4659185 100644 --- a/tensorflow/compiler/tf2xla/xla_op_registry.cc +++ b/tensorflow/compiler/tf2xla/xla_op_registry.cc @@ -63,7 +63,7 @@ XlaOpRegistry::~XlaOpRegistry() = default; if (x.name != y.name) return true; if (x.label != y.label) return true; // The registrations refer to the same Op: ensures they are compatible and - // are restricted to different device whitelists. + // are restricted to different device allowlists. 
if (x.compilation_only != y.compilation_only) { LOG(WARNING) << "Registrations of " << x.name << " have incompatible compilation_only settings."; @@ -84,14 +84,14 @@ XlaOpRegistry::~XlaOpRegistry() = default; << " have incompatible allow_string_type settings."; return false; } - if (!x.has_device_whitelist && !y.has_device_whitelist) { + if (!x.has_device_allowlist && !y.has_device_allowlist) { LOG(WARNING) << "Duplicate registrations of " << x.name - << "with no device whitelists."; + << "with no device allowlists."; return false; } - if (x.has_device_whitelist && y.has_device_whitelist) { - for (const auto& device : x.device_whitelist) { - if (y.device_whitelist.count(device) != 0) { + if (x.has_device_allowlist && y.has_device_allowlist) { + for (const auto& device : x.device_allowlist) { + if (y.device_allowlist.count(device) != 0) { LOG(WARNING) << "Multiple registrations of " << x.name << " on device " << device; return false; @@ -185,28 +185,28 @@ void XlaOpRegistry::RegisterCompilationKernels() { // The goal is to allow the co-existence of backend-specific kernels and // generic kernels. To achieve this, we enforce the following order of // registrations for one op: - // 1. Process op registration with device whitelists: + // 1. Process op registration with device allowlists: // this pass registers backend-specific kernels for this op. - // 2. Process op registration without device whitelists: + // 2. Process op registration without device allowlists: // this pass registers the kernels for all the other supported backends. for (auto& ops : registry.ops_) { const string& op_name = ops.first; std::vector>& op_registrations = ops.second; - // Partition the op registration so that the ones with device whitelists - // precede the one without device whitelist. + // Partition the op registration so that the ones with device allowlists + // precede the one without device allowlist. std::partition(op_registrations.begin(), op_registrations.end(), [](const std::unique_ptr& op_reg) { - return op_reg->has_device_whitelist; + return op_reg->has_device_allowlist; }); - // Collect a set of backend registered by ops with device whitelists. - // The op registration without whitelists will register a generic kernel + // Collect a set of backend registered by ops with device allowlists. + // The op registration without allowlists will register a generic kernel // for all other backends not in this set. - std::unordered_set whitelisted_backend; + std::unordered_set allowlisted_backend; for (auto& op_registration : op_registrations) { - if (op_registration->has_device_whitelist) { - whitelisted_backend.insert(op_registration->device_whitelist.begin(), - op_registration->device_whitelist.end()); + if (op_registration->has_device_allowlist) { + allowlisted_backend.insert(op_registration->device_allowlist.begin(), + op_registration->device_allowlist.end()); } } @@ -238,19 +238,19 @@ void XlaOpRegistry::RegisterCompilationKernels() { } for (auto& backend : registry.backends_) { - // If the operator has a device whitelist, only register on whitelisted + // If the operator has a device allowlist, only register on allowlisted // devices. 
-      if (op_registration->has_device_whitelist &&
-          op_registration->device_whitelist.find(backend.first) ==
-              op_registration->device_whitelist.end()) {
+      if (op_registration->has_device_allowlist &&
+          op_registration->device_allowlist.find(backend.first) ==
+              op_registration->device_allowlist.end()) {
         continue;
       }
 
-      // If the operator does NOT has a device whitelist, skip all devices
+      // If the operator does NOT have a device allowlist, skip all devices
       // that has already been registered.
-      if (!op_registration->has_device_whitelist &&
-          whitelisted_backend.find(backend.first) !=
-              whitelisted_backend.end()) {
+      if (!op_registration->has_device_allowlist &&
+          allowlisted_backend.find(backend.first) !=
+              allowlisted_backend.end()) {
         continue;
       }
@@ -478,17 +478,17 @@ XlaOpRegistrationBuilder XlaOpRegistrationBuilder::Name(
 
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
     absl::Span devices) {
-  registration_->has_device_whitelist = true;
+  registration_->has_device_allowlist = true;
   for (absl::string_view device : devices) {
-    registration_->device_whitelist.emplace(device);
+    registration_->device_allowlist.emplace(device);
   }
   return *this;
 }
 
 XlaOpRegistrationBuilder& XlaOpRegistrationBuilder::Device(
     absl::string_view device) {
-  registration_->has_device_whitelist = true;
-  registration_->device_whitelist.emplace(device);
+  registration_->has_device_allowlist = true;
+  registration_->device_allowlist.emplace(device);
   return *this;
 }
 
diff --git a/tensorflow/compiler/tf2xla/xla_op_registry.h b/tensorflow/compiler/tf2xla/xla_op_registry.h
index 7839ae95dc0..af720fb4bb9 100644
--- a/tensorflow/compiler/tf2xla/xla_op_registry.h
+++ b/tensorflow/compiler/tf2xla/xla_op_registry.h
@@ -258,10 +258,10 @@ class XlaOpRegistry {
     // Mapping from attribute name to a list of supported types.
     std::unordered_map> type_constraints;
 
-    // An optional whitelist of devices. If there is no whitelist, all devices
+    // An optional allowlist of devices. If there is no allowlist, all devices
     // are permitted.
-    bool has_device_whitelist = false;
-    std::unordered_set device_whitelist;
+    bool has_device_allowlist = false;
+    std::unordered_set device_allowlist;
 
     // Names of arguments that must be compile-time constants.
     std::unordered_set compile_time_constant_inputs;
@@ -279,8 +279,8 @@ class XlaOpRegistry {
   // Returns true if registrations x and y can both be added to the registry.
   // This is always the case if they refer to different ops. If they refer to
   // the same op name, they must: have the same values for compilation_only,
-  // allow_resource_types and allow_variant_types; use a device_whitelist; and
-  // their whitelists must not intersect.
+  // allow_resource_types and allow_variant_types; use a device_allowlist; and
+  // their allowlists must not intersect.
   static bool IsCompatible(const OpRegistration& x, const OpRegistration& y);
 
   static Status CompileTimeConstantInputs(const NodeDef& node_def,
@@ -319,7 +319,7 @@ class XlaOpRegistrationBuilder {
   // Starts an operator registration chain.
   static XlaOpRegistrationBuilder Name(absl::string_view name);
 
-  // Specifies a whitelist of devices on which the operator may run.
+  // Specifies an allowlist of devices on which the operator may run.
   XlaOpRegistrationBuilder& Device(absl::string_view devices);
   XlaOpRegistrationBuilder& Device(absl::Span devices);
 
diff --git a/tensorflow/core/common_runtime/graph_execution_state.cc b/tensorflow/core/common_runtime/graph_execution_state.cc
index ebe96d5dbd6..a7f17ba6a0d 100644
--- a/tensorflow/core/common_runtime/graph_execution_state.cc
+++ b/tensorflow/core/common_runtime/graph_execution_state.cc
@@ -378,7 +378,7 @@ struct TensorAndDevice {
 };
 
 // Tensors of some DataTypes cannot placed in device memory as feeds or
-// fetches. Validate against a whitelist of those known to work.
+// fetches. Validate against an allowlist of those known to work.
 bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
   // The mechanism for supporting feeds of device-backed Tensors requires
   // the _Arg kernel to be registered for the corresponding type (and that
@@ -391,7 +391,7 @@ bool IsFeedAndFetchSupported(DataType dtype, const string& device_type) {
   // For now, we return true iff there are _Arg AND _Retval kernels for dtype on
   // the device. False negatives are okay, false positives would be bad.
   //
-  // TODO(ashankar): Instead of a whitelist here, perhaps we could query
+  // TODO(ashankar): Instead of an allowlist here, perhaps we could query
   // the kernel registry for _Arg and _Retval kernels instead.
   if (device_type == DEVICE_CPU) return true;
   if (device_type != DEVICE_GPU) return false;
diff --git a/tensorflow/core/framework/dataset_stateful_op_whitelist.h b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
index 0e4f018b73e..aef66b7ed85 100644
--- a/tensorflow/core/framework/dataset_stateful_op_whitelist.h
+++ b/tensorflow/core/framework/dataset_stateful_op_whitelist.h
@@ -13,8 +13,8 @@ See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
 
-#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
-#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_
+#ifndef TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
+#define TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_
 
 #include 
 
 #include "tensorflow/core/lib/core/status.h"
@@ -23,7 +23,7 @@ namespace tensorflow {
 namespace data {
 // Registry for stateful ops that need to be used in dataset functions.
 // See below macro for usage details.
-class WhitelistedStatefulOpRegistry { +class AllowlistedStatefulOpRegistry { public: Status Add(string op_name) { op_names_.insert(std::move(op_name)); @@ -37,29 +37,29 @@ class WhitelistedStatefulOpRegistry { bool Contains(const string& op_name) { return op_names_.count(op_name); } - static WhitelistedStatefulOpRegistry* Global() { - static auto* reg = new WhitelistedStatefulOpRegistry; + static AllowlistedStatefulOpRegistry* Global() { + static auto* reg = new AllowlistedStatefulOpRegistry; return reg; } private: - WhitelistedStatefulOpRegistry() = default; - WhitelistedStatefulOpRegistry(WhitelistedStatefulOpRegistry const& copy) = + AllowlistedStatefulOpRegistry() = default; + AllowlistedStatefulOpRegistry(AllowlistedStatefulOpRegistry const& copy) = delete; - WhitelistedStatefulOpRegistry operator=( - WhitelistedStatefulOpRegistry const& copy) = delete; + AllowlistedStatefulOpRegistry operator=( + AllowlistedStatefulOpRegistry const& copy) = delete; std::unordered_set op_names_; }; } // namespace data -// Use this macro to whitelist an op that is marked stateful but needs to be +// Use this macro to allowlist an op that is marked stateful but needs to be // used inside a map_fn in an input pipeline. This is only needed if you wish // to be able to checkpoint the state of the input pipeline. We currently // do not allow stateful ops to be defined inside of map_fns since it is not // possible to save their state. -// Note that the state of the whitelisted ops inside functions will not be +// Note that the state of the allowlisted ops inside functions will not be // saved during checkpointing, hence this should only be used if the op is // marked stateful for reasons like to avoid constant folding during graph // optimization but is not stateful. 
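For reference, kernels opt into this registry at namespace scope through the macro described in the comment above; a minimal, hypothetical usage sketch (the op name "MyStatefulOp" is only a placeholder, not part of this change) would look like:

    // Hypothetical usage: mark a stateful op as usable inside tf.data
    // functions. The macro expands to a static registration that calls
    // data::AllowlistedStatefulOpRegistry::Global()->Add("MyStatefulOp").
    WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS("MyStatefulOp");

Note that the macro itself keeps its WHITELIST_ prefix in this change; only the registry class and the static status variable it creates are renamed.
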
@@ -73,9 +73,9 @@ class WhitelistedStatefulOpRegistry { #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ_HELPER(ctr, name) \ WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) #define WHITELIST_STATEFUL_OP_FOR_DATASET_FUNCTIONS_UNIQ(ctr, name) \ - static ::tensorflow::Status whitelist_op##ctr TF_ATTRIBUTE_UNUSED = \ - ::tensorflow::data::WhitelistedStatefulOpRegistry::Global()->Add(name) + static ::tensorflow::Status allowlist_op##ctr TF_ATTRIBUTE_UNUSED = \ + ::tensorflow::data::AllowlistedStatefulOpRegistry::Global()->Add(name) } // namespace tensorflow -#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_WHITELIST_H_ +#endif // TENSORFLOW_CORE_FRAMEWORK_DATASET_STATEFUL_OP_ALLOWLIST_H_ diff --git a/tensorflow/core/grappler/costs/graph_properties.cc b/tensorflow/core/grappler/costs/graph_properties.cc index 36e530916a3..c3df2c1f15b 100644 --- a/tensorflow/core/grappler/costs/graph_properties.cc +++ b/tensorflow/core/grappler/costs/graph_properties.cc @@ -542,8 +542,8 @@ bool IsNumericType(const DataType dtype) { return kRealNumberTypes->find(dtype) != kRealNumberTypes->end(); } -bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) { - static const gtl::FlatSet* const kOpTpeWhitelist = +bool IsAllowListedOpTypeForEvaluateNode(const string& op_type) { + static const gtl::FlatSet* const kOpTpeAllowlist = CHECK_NOTNULL((new gtl::FlatSet{ // Unary arithmetic ops "Floor", @@ -589,7 +589,7 @@ bool IsWhiteListedOpTypeForEvaluateNode(const string& op_type) { "Fill", "Cast", })); - return kOpTpeWhitelist->find(op_type) != kOpTpeWhitelist->end(); + return kOpTpeAllowlist->find(op_type) != kOpTpeAllowlist->end(); } // Negative shape size of '-1' represents unknown, while negative shape sizes @@ -1441,7 +1441,7 @@ class SymbolicShapeRefiner { // Due to the cost of running EvaluateNode(), we limit only to white listed // op types. - if (!IsWhiteListedOpTypeForEvaluateNode(c->op_data->op_def.name())) { + if (!IsAllowListedOpTypeForEvaluateNode(c->op_data->op_def.name())) { return false; } diff --git a/tensorflow/core/grappler/costs/graph_properties_test.cc b/tensorflow/core/grappler/costs/graph_properties_test.cc index 4137d4315bc..9fd43cdee12 100644 --- a/tensorflow/core/grappler/costs/graph_properties_test.cc +++ b/tensorflow/core/grappler/costs/graph_properties_test.cc @@ -1008,7 +1008,7 @@ TEST_F(GraphPropertiesTest, IdentityPassingShape) { TEST_F(GraphPropertiesTest, SkippingValueInferenceForLargeTensors) { // When using aggressive_shape_inference, we run EvaluateNode() for - // whitelisted ops and small input / output tensors. For instance, Fill op is + // allowlisted ops and small input / output tensors. For instance, Fill op is // evaluated and produces output tensor value if output tensor size is smal // (currently, fewer than 17 elements); otherwise we don't run EvaluateNode(). 
// This is to avoid wasting time and memory for producing huge tensors (e.g., diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc index f20c4eea0c9..252eb3c885c 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision.cc @@ -842,11 +842,11 @@ DataTypeSet AllowedDataTypes(const OpDef& op_def, const TypeAttrId& t_attr_id) { return AllowedDataTypes(*attr_def); } -Status ValidateLists(const gtl::FlatSet& white_list, +Status ValidateLists(const gtl::FlatSet& allow_list, const gtl::FlatSet& black_list, const gtl::FlatSet& gray_list, const gtl::FlatSet& clear_list) { - std::vector> lists{white_list, black_list, gray_list, + std::vector> lists{allow_list, black_list, gray_list, clear_list}; std::multiset counts; for (const auto& list : lists) { @@ -973,25 +973,25 @@ class AutoMixedPrecisionImpl { void FindTensorListImplicitFloat32Edges( const absl::flat_hash_set& tensor_list_nodes, std::vector* implicit_data_edges) const; - void AddWhitelistOps(absl::flat_hash_set* white_set) const; + void AddAllowlistOps(absl::flat_hash_set* allow_set) const; void PropagateBlackFwdThroughClearAndGray( absl::flat_hash_set* black_set) const; void ForceColorMatchBetweenTensorListOps( const absl::flat_hash_set& tensor_list_nodes, - absl::flat_hash_set* white_set, + absl::flat_hash_set* allow_set, absl::flat_hash_set* black_set) const; - void AddClearAndGrayToWhiteIfBetweenWhite( + void AddClearAndGrayToAllowIfBetweenAllow( const absl::flat_hash_set& black_set, - absl::flat_hash_set* white_set) const; - void PropagateWhiteThroughClear(const absl::flat_hash_set& black_set, - absl::flat_hash_set* white_set) const; + absl::flat_hash_set* allow_set) const; + void PropagateAllowThroughClear(const absl::flat_hash_set& black_set, + absl::flat_hash_set* allow_set) const; Status ForceColorMatchOnRecurrentEdges( - absl::flat_hash_set* white_set) const; - void MakeCastsWhiteIfAllOutputsWhite( - absl::flat_hash_set* white_set) const; + absl::flat_hash_set* allow_set) const; + void MakeCastsAllowIfAllOutputsAllow( + absl::flat_hash_set* allow_set) const; NodeDef BuildCastNode(const MutableGraphView::OutputPort& src, bool to_f16, const string& device) const; - Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set& white_set); + Status ChangeTypeAttrsAndAddCasts(const absl::flat_hash_set& allow_set); VirtualPlacer virtual_placer_; std::unordered_set nodes_to_preserve_; @@ -1005,7 +1005,7 @@ class AutoMixedPrecisionImpl { GraphTypeTopologyView graph_type_view_; bool force_all_fp16_; AutoMixedPrecisionMode mode_; - gtl::FlatSet f16_whitelist_; + gtl::FlatSet f16_allowlist_; gtl::FlatSet f16_blacklist_; gtl::FlatSet f16_graylist_; gtl::FlatSet f16_clearlist_; @@ -1079,8 +1079,8 @@ Status AutoMixedPrecisionImpl::PrintDebugLogs(bool preop, size_t timestamp) { f.open(fname.c_str(), std::fstream::out); std::unique_ptr mp_lists = get_mixed_precision_lists(); - f << "WhiteList:\n"; - for (const auto& x : mp_lists->WhiteList()) { + f << "AllowList:\n"; + for (const auto& x : mp_lists->AllowList()) { f << x << "\n"; } f << "\nBlackList:\n"; @@ -1254,11 +1254,11 @@ Status AutoMixedPrecisionImpl::Optimize() { std::unique_ptr mp_lists = get_mixed_precision_lists(); - f16_whitelist_ = mp_lists->WhiteList(); + f16_allowlist_ = mp_lists->AllowList(); f16_blacklist_ = mp_lists->BlackList(); f16_graylist_ = mp_lists->GrayList(); f16_clearlist_ = mp_lists->ClearList(); - 
TF_RETURN_IF_ERROR(ValidateLists(f16_whitelist_, f16_blacklist_, + TF_RETURN_IF_ERROR(ValidateLists(f16_allowlist_, f16_blacklist_, f16_graylist_, f16_clearlist_)); size_t timestamp = Env::Default()->NowMicros() / 1000; @@ -1316,8 +1316,8 @@ Status AutoMixedPrecisionImpl::Optimize() { // boundaries between f16/non-f16 nodes. // The algorithm for deciding which nodes to change to f16 is as follows: - // 1) Add all performance-critical ops (aka "whitelist" ops) to the white_set. - // This is done under the assumption that whitelist ops are always + // 1) Add all performance-critical ops (aka "allowlist" ops) to the allow_set. + // This is done under the assumption that allowlist ops are always // numerically-safe in f16 and that they are the most important ops for // improving performance. // 2) Add nodes to the black_set iff they are numerically-dangerous (aka @@ -1329,20 +1329,20 @@ Status AutoMixedPrecisionImpl::Optimize() { // numerical accuracy of the model. // 3) For all remaining nodes that are not considered dangerous (greylist // and clearlist ops), find those that are between (i.e., both upstream - // and downstream of) white nodes, and add them to the white_set. - // This is done to avoid unnecessary casts between whitelist ops. - // 4) For all remaining clearlist nodes, add them to the white_set if they are - // connected to a node in the white_set via other clearlist nodes. - // This is done to increase the number of ops in the white_set without + // and downstream of) allow nodes, and add them to the allow_set. + // This is done to avoid unnecessary casts between allowlist ops. + // 4) For all remaining clearlist nodes, add them to the allow_set if they are + // connected to a node in the allow_set via other clearlist nodes. + // This is done to increase the number of ops in the allow_set without // affecting numerical stability. 
- absl::flat_hash_set white_set; - VLOG(2) << "Beginning pass 1 to add whitelist ops"; - AddWhitelistOps(&white_set); + absl::flat_hash_set allow_set; + VLOG(2) << "Beginning pass 1 to add allowlist ops"; + AddAllowlistOps(&allow_set); VLOG(2) << "Finished pass 1"; - if (white_set.empty()) { - LOG(INFO) << "No whitelist ops found, nothing to do"; + if (allow_set.empty()) { + LOG(INFO) << "No allowlist ops found, nothing to do"; return Status::OK(); } @@ -1353,33 +1353,33 @@ Status AutoMixedPrecisionImpl::Optimize() { VLOG(2) << "Forcing color match between data structure ops"; for (const auto& cluster : tensor_list_clusters) { - ForceColorMatchBetweenTensorListOps(cluster, &white_set, &black_set); + ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set); } - VLOG(2) << "Beginning pass 3 to set clear and gray nodes to white if they " - "are between white ops"; - AddClearAndGrayToWhiteIfBetweenWhite(black_set, &white_set); + VLOG(2) << "Beginning pass 3 to set clear and gray nodes to allow if they " + "are between allow ops"; + AddClearAndGrayToAllowIfBetweenAllow(black_set, &allow_set); VLOG(2) << "Finished pass 3"; - VLOG(2) << "Beginning pass 4 to propagate white from white nodes through " + VLOG(2) << "Beginning pass 4 to propagate allow from allow nodes through " "clearlist ops"; - PropagateWhiteThroughClear(black_set, &white_set); + PropagateAllowThroughClear(black_set, &allow_set); VLOG(2) << "Finished pass 4"; VLOG(2) << "Forcing color match between data structure ops"; for (const auto& cluster : tensor_list_clusters) { - ForceColorMatchBetweenTensorListOps(cluster, &white_set, &black_set); + ForceColorMatchBetweenTensorListOps(cluster, &allow_set, &black_set); } VLOG(2) << "Forcing color match on loop edges"; - TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&white_set)); + TF_RETURN_IF_ERROR(ForceColorMatchOnRecurrentEdges(&allow_set)); - VLOG(2) << "Finding existing casts that can be made white"; - MakeCastsWhiteIfAllOutputsWhite(&white_set); + VLOG(2) << "Finding existing casts that can be made allow"; + MakeCastsAllowIfAllOutputsAllow(&allow_set); VLOG(2) << "Beginning final pass to change type attributes and insert Cast " "ops at paint boundaries"; - TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(white_set)); + TF_RETURN_IF_ERROR(ChangeTypeAttrsAndAddCasts(allow_set)); VLOG(2) << "Finished final pass"; TF_RETURN_IF_ERROR(PrintDebugLogs(/* preop = */ false, timestamp)); @@ -1516,19 +1516,19 @@ void AutoMixedPrecisionImpl::FindTensorListImplicitFloat32Edges( } } -void AutoMixedPrecisionImpl::AddWhitelistOps( - absl::flat_hash_set* white_set) const { - // Add whitelisted ops to white_set. +void AutoMixedPrecisionImpl::AddAllowlistOps( + absl::flat_hash_set* allow_set) const { + // Add allowlisted ops to allow_set. 
for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); if (!ShouldProcess(*root.node)) continue; - bool force_white = force_all_fp16_ && CanForceFP16(*root.node); - if (f16_whitelist_.count(root.node->op()) || force_white) { - bool inserted = white_set->insert(root_idx).second; + bool force_allow = force_all_fp16_ && CanForceFP16(*root.node); + if (f16_allowlist_.count(root.node->op()) || force_allow) { + bool inserted = allow_set->insert(root_idx).second; if (VLOG_IS_ON(2) && inserted) { VLOG(2) << "Painting type " << root.type_attr.DebugString() - << " of node " << root.node->name() << " WHITE because its op " - << root.node->op() << " is on the whitelist"; + << " of node " << root.node->name() << " ALLOW because its op " + << root.node->op() << " is on the allowlist"; } } } @@ -1537,8 +1537,8 @@ void AutoMixedPrecisionImpl::AddWhitelistOps( // Adds nodes to black_set iff they are on the blacklist or they are on a // forward path from a blacklist node to a black/gray node (including the node // at the end of the path) through clear and gray nodes. -// E.g., black -> gray -> clear -> gray -> clear -> white -> gray -// becomes: black -> black -> black -> black -> clear -> white -> gray. +// E.g., black -> gray -> clear -> gray -> clear -> allow -> gray +// becomes: black -> black -> black -> black -> clear -> allow -> gray. void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray( absl::flat_hash_set* black_set) const { if (force_all_fp16_) return; @@ -1588,14 +1588,14 @@ void AutoMixedPrecisionImpl::PropagateBlackFwdThroughClearAndGray( } } -void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite( +void AutoMixedPrecisionImpl::AddClearAndGrayToAllowIfBetweenAllow( const absl::flat_hash_set& black_set, - absl::flat_hash_set* white_set) const { - // Find clear/graylist ops that are downstream of white ops. - absl::flat_hash_set downstream_of_white_set; + absl::flat_hash_set* allow_set) const { + // Find clear/graylist ops that are downstream of allow ops. + absl::flat_hash_set downstream_of_allow_set; for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); - if (!ShouldProcess(*root.node) || !f16_whitelist_.count(root.node->op())) { + if (!ShouldProcess(*root.node) || !f16_allowlist_.count(root.node->op())) { continue; } DfsTypeTraversal( @@ -1603,8 +1603,8 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite( DfsTypePredicates::Enter([&](int idx) -> bool { const NodeTypeId& item = *graph_type_view_.GetNode(idx); return idx == root_idx || - (!downstream_of_white_set.count(idx) && - !f16_whitelist_.count(item.node->op()) && + (!downstream_of_allow_set.count(idx) && + !f16_allowlist_.count(item.node->op()) && !black_set.count(idx) && ShouldProcess(*item.node) && // TODO(benbarsdell): Consider allowing propagation through // ops that are already float16 in order to reduce the number @@ -1614,45 +1614,45 @@ void AutoMixedPrecisionImpl::AddClearAndGrayToWhiteIfBetweenWhite( f16_graylist_.count(item.node->op()))); }), DfsTypeCallbacks::PreOrder( - [&](int idx) { downstream_of_white_set.insert(idx); })); + [&](int idx) { downstream_of_allow_set.insert(idx); })); } - // Set nodes that are both downstream and upstream of white ops to white. - absl::flat_hash_set upstream_of_white_set; + // Set nodes that are both downstream and upstream of allow ops to allow. 
+ absl::flat_hash_set upstream_of_allow_set; for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); - if (!ShouldProcess(*root.node) || upstream_of_white_set.count(root_idx) || - !f16_whitelist_.count(root.node->op())) { + if (!ShouldProcess(*root.node) || upstream_of_allow_set.count(root_idx) || + !f16_allowlist_.count(root.node->op())) { continue; } DfsTypeTraversal( graph_type_view_, {&root}, TypeTraversalDirection::kFollowInputs, DfsTypePredicates::Enter([&](int idx) -> bool { - return idx == root_idx || (!upstream_of_white_set.count(idx) && - downstream_of_white_set.count(idx)); + return idx == root_idx || (!upstream_of_allow_set.count(idx) && + downstream_of_allow_set.count(idx)); }), DfsTypeCallbacks::PreOrder([&](int idx) { - upstream_of_white_set.insert(idx); - bool inserted = white_set->insert(idx).second; + upstream_of_allow_set.insert(idx); + bool inserted = allow_set->insert(idx).second; if (VLOG_IS_ON(2) && inserted) { const NodeTypeId& item = *graph_type_view_.GetNode(idx); VLOG(2) << "Painting type " << item.type_attr.DebugString() << " of " << item.node->op() << " node " - << item.node->name() << " WHITE"; + << item.node->name() << " ALLOW"; } })); } } -void AutoMixedPrecisionImpl::PropagateWhiteThroughClear( +void AutoMixedPrecisionImpl::PropagateAllowThroughClear( const absl::flat_hash_set& black_set, - absl::flat_hash_set* white_set) const { - // Propagate white from white nodes through clearlist ops. + absl::flat_hash_set* allow_set) const { + // Propagate allow from allow nodes through clearlist ops. absl::flat_hash_set clear_prop_set; for (int root_idx = 0; root_idx < graph_type_view_.num_nodes(); ++root_idx) { const NodeTypeId& root = *graph_type_view_.GetNode(root_idx); if (!ShouldProcess(*root.node) || clear_prop_set.count(root_idx) || - !white_set->count(root_idx)) { + !allow_set->count(root_idx)) { continue; } DfsTypeTraversal( @@ -1661,7 +1661,7 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear( DfsTypePredicates::Enter([&](int idx) -> bool { const NodeTypeId& item = *graph_type_view_.GetNode(idx); return idx == root_idx || - (!white_set->count(idx) && !black_set.count(idx) && + (!allow_set->count(idx) && !black_set.count(idx) && ShouldProcess(*item.node) && IsFloat32(item) && SupportsF16(item) && (f16_clearlist_.count(item.node->op())) && @@ -1673,30 +1673,30 @@ void AutoMixedPrecisionImpl::PropagateWhiteThroughClear( }), DfsTypeCallbacks::PreOrder([&](int idx) { clear_prop_set.insert(idx); - bool inserted = white_set->insert(idx).second; + bool inserted = allow_set->insert(idx).second; if (VLOG_IS_ON(2) && inserted) { const NodeTypeId& item = *graph_type_view_.GetNode(idx); VLOG(2) << "Painting type " << item.type_attr.DebugString() << " of " << item.node->op() << " node " - << item.node->name() << " WHITE"; + << item.node->name() << " ALLOW"; } })); } } // Forces NextIteration nodes and their output Merge node(s) to have the same -// color. Specifically, it removes them all from white_set if any of the Merge -// nodes is not in white_set, otherwise it adds the NextIteration node to -// white_set. +// color. Specifically, it removes them all from allow_set if any of the Merge +// nodes is not in allow_set, otherwise it adds the NextIteration node to +// allow_set. 
Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges( - absl::flat_hash_set* white_set) const { + absl::flat_hash_set* allow_set) const { for (const NodeDef& node : graph_->node()) { if (node.op() == "NextIteration") { GraphView::OutputPort output_port(&node, 0); const auto& fanout = graph_view_.GetFanout(output_port); std::vector merge_idxs; merge_idxs.reserve(fanout.size()); - bool any_merge_is_not_white = false; + bool any_merge_is_not_allow = false; for (const auto& output : fanout) { const NodeDef& merge_node = *output.node; if (merge_node.op() != "Merge") { @@ -1712,8 +1712,8 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges( } int merge_idx = maybe_merge_idx.value(); merge_idxs.push_back(merge_idx); - any_merge_is_not_white = - any_merge_is_not_white || !white_set->count(merge_idx); + any_merge_is_not_allow = + any_merge_is_not_allow || !allow_set->count(merge_idx); } const absl::optional maybe_nextiter_idx = graph_type_view_.GetNodeIndex(node.name(), TypeAttrId("T")); @@ -1722,9 +1722,9 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges( node.name(), " not found in graph view"); } int nextiter_idx = maybe_nextiter_idx.value(); - if (any_merge_is_not_white) { + if (any_merge_is_not_allow) { for (int merge_idx : merge_idxs) { - if (white_set->erase(merge_idx)) { + if (allow_set->erase(merge_idx)) { VLOG(2) << "Painting type T of Merge node " << graph_type_view_.GetNode(merge_idx)->node->name() << " BLACK to match the color of its sibling Merge nodes " @@ -1732,14 +1732,14 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges( << node.name(); } } - if (white_set->erase(nextiter_idx)) { + if (allow_set->erase(nextiter_idx)) { VLOG(2) << "Painting type T of NextIteration node " << node.name() << " BLACK to match the color of its output Merge node(s)"; } } else { - if (white_set->insert(nextiter_idx).second) { + if (allow_set->insert(nextiter_idx).second) { VLOG(2) << "Painting type T of NextIteration node " << node.name() - << " WHITE to match the color of its output Merge node(s)"; + << " ALLOW to match the color of its output Merge node(s)"; } } } @@ -1750,10 +1750,10 @@ Status AutoMixedPrecisionImpl::ForceColorMatchOnRecurrentEdges( // Forces all of the given Tensor List nodes into the same color set. void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps( const absl::flat_hash_set& tensor_list_nodes, - absl::flat_hash_set* white_set, + absl::flat_hash_set* allow_set, absl::flat_hash_set* black_set) const { bool any_black = false; - bool any_white = false; + bool any_allow = false; std::vector node_type_idxs; node_type_idxs.reserve(tensor_list_nodes.size()); for (const NodeDef* node : tensor_list_nodes) { @@ -1769,23 +1769,23 @@ void AutoMixedPrecisionImpl::ForceColorMatchBetweenTensorListOps( if (black_set->count(node_type_idx)) { any_black = true; break; - } else if (white_set->count(node_type_idx)) { - any_white = true; + } else if (allow_set->count(node_type_idx)) { + any_allow = true; } } - if (!any_black && !any_white) return; + if (!any_black && !any_allow) return; for (int node_type_idx : node_type_idxs) { const NodeTypeId& node_type = *graph_type_view_.GetNode(node_type_idx); VLOG(2) << "Painting type " << node_type.type_attr.DebugString() << " of " << node_type.node->op() << " node " << node_type.node->name() << " " - << (any_black ? "BLACK" : "WHITE") + << (any_black ? "BLACK" : "ALLOW") << " because at least one of its siblings is " - << (any_black ? "BLACK" : "WHITE"); + << (any_black ? 
"BLACK" : "ALLOW"); if (any_black) { - white_set->erase(node_type_idx); + allow_set->erase(node_type_idx); black_set->insert(node_type_idx); } else { - white_set->insert(node_type_idx); + allow_set->insert(node_type_idx); } } } @@ -1807,10 +1807,10 @@ bool AutoMixedPrecisionImpl::NodeImplicitlyReadsNonResourceVariable( return false; } -// This adds existing Cast nodes to white_set if all of their outputs are white, +// This adds existing Cast nodes to allow_set if all of their outputs are allow, // avoiding the need to add a new Cast node after an existing Cast. -void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite( - absl::flat_hash_set* white_set) const { +void AutoMixedPrecisionImpl::MakeCastsAllowIfAllOutputsAllow( + absl::flat_hash_set* allow_set) const { int num_nodes_preop = graph_->node_size(); for (int node_idx = 0; node_idx < num_nodes_preop; ++node_idx) { NodeDef* node = graph_->mutable_node(node_idx); @@ -1818,7 +1818,7 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite( if (node->op() != "Cast" || !IsFloat32(node_type)) { continue; } - bool all_fanouts_white = true; + bool all_fanouts_allow = true; MutableGraphView::OutputPort src(node, 0); const auto& fanout = graph_view_.GetFanout(src); for (const MutableGraphView::InputPort& dst : fanout) { @@ -1830,13 +1830,13 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite( << "Type attribute " << dst_type_attr.DebugString() << " of node " << dst.node->name() << " not found in graph view"; int dst_type_idx = maybe_dst_type_idx.value(); - bool dst_is_white = white_set->count(dst_type_idx); - if (!dst_is_white) { - all_fanouts_white = false; + bool dst_is_allow = allow_set->count(dst_type_idx); + if (!dst_is_allow) { + all_fanouts_allow = false; break; } } - if (!fanout.empty() && all_fanouts_white) { + if (!fanout.empty() && all_fanouts_allow) { const absl::optional maybe_node_type_idx = graph_type_view_.GetNodeIndex(node_type); DCHECK(maybe_node_type_idx.has_value()) @@ -1844,16 +1844,16 @@ void AutoMixedPrecisionImpl::MakeCastsWhiteIfAllOutputsWhite( << " of node " << node_type.node->name() << " not found in graph view"; int node_type_idx = maybe_node_type_idx.value(); - white_set->insert(node_type_idx); + allow_set->insert(node_type_idx); } } } -// Changes all white-painted type attributes to DT_HALF or DT_BFLOAT16, and +// Changes all allow-painted type attributes to DT_HALF or DT_BFLOAT16, and // inserts Cast nodes at node outputs for all edges that connect -// white-painted <-> non-white-painted type attributes. +// allow-painted <-> non-allow-painted type attributes. 
Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts( - const absl::flat_hash_set& white_set) { + const absl::flat_hash_set& allow_set) { int num_nodes_changed = 0; int num_nonvar_casts_to_f16 = 0; int num_nodes_preop = graph_->node_size(); @@ -1869,8 +1869,8 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts( } int node_type_idx = maybe_node_type_idx.value(); if (!IsFloat32(*graph_type_view_.GetNode(node_type_idx))) continue; - bool src_is_white = white_set.count(node_type_idx); - if (src_is_white) { + bool src_is_allow = allow_set.count(node_type_idx); + if (src_is_allow) { VLOG(1) << "Changing type " << type_attr.DebugString() << " of " << node->op() << " node " << node->name() << " to " << DataTypeString(target_dtype_); @@ -1896,10 +1896,10 @@ Status AutoMixedPrecisionImpl::ChangeTypeAttrsAndAddCasts( " not found in graph view"); } int dst_type_idx = maybe_dst_type_idx.value(); - bool dst_is_white = white_set.count(dst_type_idx); - if (src_is_white != dst_is_white) { + bool dst_is_allow = allow_set.count(dst_type_idx); + if (src_is_allow != dst_is_allow) { if (!added_cast_node) { - bool to_f16 = dst_is_white; + bool to_f16 = dst_is_allow; VLOG(1) << "Inserting cast to " << (to_f16 ? DataTypeString(target_dtype_) : "DT_FLOAT") << " at " << src.node->op() << " " << src.node->name() diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h index a9840110d81..6643149a6e5 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_lists.h @@ -23,7 +23,7 @@ limitations under the License. namespace tensorflow { namespace grappler { -// Represents the four lists of ops: the white list, gray list, black list, and +// Represents the four lists of ops: the allow list, gray list, black list, and // clear list. These lists determine which ops are converted to fp16/bf16 // (referred to as 'f16' for short) and which ops stay as fp32. class AutoMixedPrecisionLists { @@ -33,7 +33,7 @@ class AutoMixedPrecisionLists { // Returns the set of ops that are considered numerically-safe (for execution // in f16), performance-critical, and can run in f16. These ops are always // converted to f16. - virtual gtl::FlatSet WhiteList() = 0; + virtual gtl::FlatSet AllowList() = 0; // Returns the set of ops that can run in f16 and are considered numerically- // safe (for execution in f16), but which may be made unsafe by an upstream // blacklist op. @@ -51,8 +51,10 @@ class AutoMixedPrecisionLists { protected: // Adds or removes ops from list if certain environmental variables are set. static void UpdateList(const string& list_name, gtl::FlatSet* list) { - CHECK(list_name == "WHITELIST" || list_name == "GRAYLIST" || // Crash OK. - list_name == "BLACKLIST" || list_name == "CLEARLIST"); + CHECK(list_name == "ALLOWLIST" || list_name == "GRAYLIST" || // Crash OK. 
+ list_name == "BLACKLIST" || list_name == "CLEARLIST" || + // TODO(reedwm): for bkwds compat; remove when no longer necessary: + list_name == "WHITELIST"); string add_env_var = "TF_AUTO_MIXED_PRECISION_GRAPH_REWRITE_" + list_name + "_ADD"; string remove_env_var = @@ -104,7 +106,7 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { AutoMixedPrecisionListsCuda(int cuda_version, int cudnn_version) : cuda_version_(cuda_version), cudnn_version_(cudnn_version) {} - gtl::FlatSet WhiteList() override { + gtl::FlatSet AllowList() override { auto list = gtl::FlatSet{ "BlockLSTM", "BlockLSTMV2", @@ -144,7 +146,11 @@ class AutoMixedPrecisionListsCuda : public AutoMixedPrecisionLists { list.insert("Conv3DBackpropInput"); list.insert("Conv3DBackpropInputV2"); } + UpdateList("ALLOWLIST", &list); + // For backwards compatibility, keeping the original env variable here. + // TODO(reedwm): This should be removed if we don't have active users. UpdateList("WHITELIST", &list); + return list; } @@ -338,8 +344,8 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { AutoMixedPrecisionListsMkl() {} // Only ops which are supported by MKL in bfloat16 should be added to the - // white list, gray list, or clear list. - gtl::FlatSet WhiteList() override { + // allow list, gray list, or clear list. + gtl::FlatSet AllowList() override { auto list = gtl::FlatSet{"Conv2D", "Conv2DBackpropFilter", "Conv2DBackpropInput", @@ -353,7 +359,7 @@ class AutoMixedPrecisionListsMkl : public AutoMixedPrecisionLists { "BatchMatMul", "BatchMatMulV2"}; - UpdateList("WHITELIST", &list); + UpdateList("ALLOWLIST", &list); return list; } diff --git a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc index 5c18966c895..eef1f4c499a 100644 --- a/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc +++ b/tensorflow/core/grappler/optimizers/auto_mixed_precision_test.cc @@ -169,10 +169,10 @@ class AutoMixedPrecisionTest : public GrapplerTest { Output eye = ops::Const(s.WithOpName("eye"), GenerateIdentityMatrix(size, size)); Output input = ops::Placeholder(s.WithOpName("input"), DT_FLOAT); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, eye); - Output gry1 = test_op_factory(s.WithOpName("gry1"), wht1); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, eye); - Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, eye); + Output gry1 = test_op_factory(s.WithOpName("gry1"), allow1); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, eye); + Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2); GrapplerItem item; item.fetch = {"fetch1"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); @@ -190,9 +190,9 @@ class AutoMixedPrecisionTest : public GrapplerTest { GraphView output_view(&output); EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); auto tensors = EvaluateNodes(output, item.fetch, feed); EXPECT_EQ(tensors.size(), tensors_expected.size()); @@ -247,8 +247,8 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) { tensorflow::Scope s = 
tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_HALF); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1); - Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1); + Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1); Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT); Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2); Output fetch = ops::Identity(s.WithOpName("fetch"), clr2); @@ -267,7 +267,7 @@ TEST_F(AutoMixedPrecisionTest, AlreadyFp16) { GraphView output_view(&output); EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT); @@ -288,8 +288,8 @@ TEST_F(AutoMixedPrecisionTest, Simple) { Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1); Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1); Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2); - Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2); + Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1); Output gry2 = ops::Log(s.WithOpName("gry2"), clr3); Output clr4 = ops::Relu(s.WithOpName("clr4"), gry2); Output blk2 = ops::SparseMatMul(s.WithOpName("blk2"), clr4, clr4); @@ -314,7 +314,7 @@ TEST_F(AutoMixedPrecisionTest, Simple) { EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry2")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT); @@ -335,10 +335,10 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); Output clr2 = ops::Relu(s.WithOpName("clr2"), input); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr1, clr1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1); auto clr3 = ops::ShapeN(s.WithOpName("clr3"), {clr1, clr2}); Output clr4 = ops::Relu(s.WithOpName("clr4"), clr2); - Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1); + Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1); Output fetch2 = ops::Identity(s.WithOpName("fetch2"), clr4); GrapplerItem item; @@ -357,7 +357,7 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF); - 
EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_HALF); @@ -372,18 +372,18 @@ TEST_F(AutoMixedPrecisionTest, BidirectionalClearChain) { TEST_F(AutoMixedPrecisionTest, PreserveFetches) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); - Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); + Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1); Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1); Output blk1 = ops::Exp(s.WithOpName("blk1"), gry1); Output clr2 = ops::Relu(s.WithOpName("clr2"), blk1); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), clr2, clr2); - Output clr3 = ops::Relu(s.WithOpName("clr3"), wht2); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), clr2, clr2); + Output clr3 = ops::Relu(s.WithOpName("clr3"), allow2); Output blk2 = ops::Exp(s.WithOpName("blk2"), clr3); Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2); GrapplerItem item; - item.fetch = {"wht1", "clr2", "clr3"}; + item.fetch = {"allow1", "clr2", "clr3"}; TF_CHECK_OK(s.ToGraphDef(&item.graph)); auto tensors_expected = EvaluateNodes(item.graph, item.fetch); @@ -396,12 +396,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveFetches) { GraphView output_view(&output); EXPECT_EQ(output.node_size(), item.graph.node_size() + 2); EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_FLOAT); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT); @@ -418,12 +418,13 @@ TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output clr1 = ops::Relu(s.WithOpName("clr1"), input); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr1, clr1); - Output gry1 = ops::Tanh(s.WithOpName("gry1"), wht1); - Output wht2 = ops::MatMul(s.WithOpName("wht2").WithDevice( - "/job:localhost/replica:0/task:0/device:CPU:0"), - gry1, gry1); - Output clr2 = ops::Relu(s.WithOpName("clr2"), wht2); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr1, clr1); + Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1); + Output allow2 = + ops::MatMul(s.WithOpName("allow2").WithDevice( + "/job:localhost/replica:0/task:0/device:CPU:0"), + gry1, gry1); + Output clr2 = ops::Relu(s.WithOpName("clr2"), allow2); Output fetch = ops::Identity(s.WithOpName("fetch"), clr2); GrapplerItem item; @@ -441,9 +442,9 @@ 
TEST_F(AutoMixedPrecisionTest, PreserveCPUNodes) { EXPECT_EQ(output.node_size(), item.graph.node_size() + 2); EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_FLOAT); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_FLOAT); auto tensors = EvaluateNodes(output, item.fetch); @@ -459,12 +460,12 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) { Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); Output var1 = ops::Variable(s.WithOpName("var1"), {32, 32}, DT_FLOAT); Output clr1 = ops::Identity(s.WithOpName("clr1"), var1); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, clr1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, clr1); Output input2 = ops::Const(s.WithOpName("input2"), 1.f / 32, {32, 32}); Output clr2 = ops::Identity(s.WithOpName("clr2"), input2); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), input, clr2); - Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1); - Output fetch2 = ops::Identity(s.WithOpName("fetch2"), wht2); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), input, clr2); + Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1); + Output fetch2 = ops::Identity(s.WithOpName("fetch2"), allow2); GrapplerItem item; item.fetch = {"fetch1", "fetch2"}; @@ -485,10 +486,10 @@ TEST_F(AutoMixedPrecisionTest, PreserveIdentityAfterVariable) { EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("var1")->attr().at("dtype").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("input2")->attr().at("dtype").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); auto tensors = EvaluateNodes(output, item.fetch, feed); EXPECT_EQ(tensors.size(), tensors_expected.size()); @@ -507,22 +508,24 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { Output offset = ops::Const(s.WithOpName("offset"), 4.f, {16}); Output mean = ops::Const(s.WithOpName("mean"), 5.f, {0}); Output variance = ops::Const(s.WithOpName("variance"), 6.f, {0}); - Output wht1 = ops::Conv2D(s.WithOpName("wht1"), input, weight, {1, 1, 1, 1}, - "SAME", ops::Conv2D::DataFormat("NHWC")); + Output allow1 = + ops::Conv2D(s.WithOpName("allow1"), input, weight, {1, 1, 1, 1}, "SAME", + ops::Conv2D::DataFormat("NHWC")); auto fbn1_op = - ops::FusedBatchNorm(s.WithOpName("fbn1"), wht1, scale, offset, mean, + ops::FusedBatchNorm(s.WithOpName("fbn1"), allow1, scale, offset, mean, variance, ops::FusedBatchNorm::DataFormat("NHWC")); Output fbn1 = fbn1_op.y; Output fbn1_rs1 = fbn1_op.reserve_space_1; Output fbn1_rs2 = fbn1_op.reserve_space_2; Output bng1 = ops::FusedBatchNormGrad( - s.WithOpName("bng1"), fbn1, wht1, scale, fbn1_rs1, fbn1_rs2, 
- ops::FusedBatchNormGrad::DataFormat("NHWC")) + s.WithOpName("bng1"), fbn1, allow1, scale, fbn1_rs1, + fbn1_rs2, ops::FusedBatchNormGrad::DataFormat("NHWC")) .x_backprop; Output gry1 = ops::Add(s.WithOpName("gry1"), fbn1, bng1); - Output wht2 = ops::Conv2D(s.WithOpName("wht2"), gry1, weight, {1, 1, 1, 1}, - "SAME", ops::Conv2D::DataFormat("NHWC")); - Output fetch = ops::Identity(s.WithOpName("fetch"), wht2); + Output allow2 = + ops::Conv2D(s.WithOpName("allow2"), gry1, weight, {1, 1, 1, 1}, "SAME", + ops::Conv2D::DataFormat("NHWC")); + Output fetch = ops::Identity(s.WithOpName("fetch"), allow2); GrapplerItem item; item.fetch = {"fetch"}; @@ -537,7 +540,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { GraphView output_view(&output); EXPECT_EQ(output.node_size(), item.graph.node_size() + 3); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("fbn1")->op(), "FusedBatchNormV2"); EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("fbn1")->attr().at("U").type(), DT_FLOAT); @@ -545,7 +548,7 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { EXPECT_EQ(output_view.GetNode("bng1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("bng1")->attr().at("U").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(tensors.size(), tensors_expected.size()); @@ -558,13 +561,13 @@ TEST_F(AutoMixedPrecisionTest, FusedBatchNorm) { TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); - auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {wht1, wht1, wht1}); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); + auto clr1_op = ops::IdentityN(s.WithOpName("clr1"), {allow1, allow1, allow1}); Output gry1 = ops::AddN(s.WithOpName("gry1"), {clr1_op.output[0], clr1_op.output[1], clr1_op.output[2]}); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1); - Output fetch = ops::Identity(s.WithOpName("fetch"), wht2); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1); + Output fetch = ops::Identity(s.WithOpName("fetch"), allow2); GrapplerItem item; item.fetch = {"fetch"}; @@ -580,12 +583,12 @@ TEST_F(AutoMixedPrecisionTest, RepeatedAndListTypeAttrs) { GraphView output_view(&output); EXPECT_EQ(output.node_size(), item.graph.node_size() + 2); EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); for (auto type : output_view.GetNode("clr1")->attr().at("T").list().type()) { EXPECT_EQ(type, DT_HALF); } EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(tensors.size(), tensors_expected.size()); @@ -599,8 +602,8 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) { 
tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), true, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_FLOAT); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1); - Output fetch = ops::Identity(s.WithOpName("fetch"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1); + Output fetch = ops::Identity(s.WithOpName("fetch"), allow1); GrapplerItem item; item.fetch = {"fetch"}; @@ -617,7 +620,7 @@ TEST_F(AutoMixedPrecisionTest, ExistingCast) { EXPECT_EQ(output.node_size(), item.graph.node_size() + 1); EXPECT_EQ(output_view.GetNode("cst1")->attr().at("SrcT").type(), DT_BOOL); EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); auto tensors = EvaluateNodes(output, item.fetch); EXPECT_EQ(tensors.size(), tensors_expected.size()); @@ -640,8 +643,8 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) { Output lpc1 = ops::LoopCond(s.WithOpName("lpc1"), con1).output; auto swt1 = ops::Switch(s.WithOpName("swt1"), mrg1, lpc1); Output gry1 = ops::Sqrt(s.WithOpName("gry1"), swt1.output_true); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), gry1, gry1); - Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), gry1, gry1); + Output nxt1 = ops::NextIteration(s.WithOpName("nxt1"), allow1); Output ext1 = ops::internal::Exit(s.WithOpName("ext1"), swt1.output_false); Output fetch = ops::Identity(s.WithOpName("fetch"), ext1); // Add a second merge node from the same NextIteration node. This case arises @@ -670,13 +673,13 @@ TEST_F(AutoMixedPrecisionTest, RecurrentEdgeColorMismatch) { EXPECT_EQ(output.node_size(), item.graph.node_size() + 2); // Note that mrg1 gets painted black because it is between blk1 and gry1. This // forces nxt1 and mrg2 to be painted black as well (they would otherwise be - // painted white because they are clear and have a direct path to wht1). + // painted allow because they are clear and have a direct path to allow1). EXPECT_EQ(output_view.GetNode("blk1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("ent1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("mrg1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("swt1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("nxt1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("ext1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("mrg2")->attr().at("T").type(), DT_FLOAT); @@ -699,9 +702,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { Output idx3 = ops::Const(s.WithOpName("idx3"), 3); auto tl1w1 = ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); auto tl1w2 = - ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1); + ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1); // Ensure that TensorListResize doesn't cause any problems. 
Output tl1rs = ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6); @@ -709,9 +712,9 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { shape, DT_FLOAT) .item; Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1); auto tl1w3 = - ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2); + ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2); Output tl1r2 = ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3, shape, DT_FLOAT) @@ -742,11 +745,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListSetGet) { const char* type_key = "element_dtype"; EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT); @@ -767,15 +770,16 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) { Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, input); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); - auto tl1w2 = - ops::TensorListPushBack(s.WithOpName("tl1w2"), tl1w1.output_handle, wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); + auto tl1w2 = ops::TensorListPushBack(s.WithOpName("tl1w2"), + tl1w1.output_handle, allow1); Output tl1r1 = ops::TensorListPopBack(s.WithOpName("tl1r1"), tl1w2.output_handle, shape, DT_FLOAT) .tensor; Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1); - auto tl1w3 = ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, wht2); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1); + auto tl1w3 = + ops::TensorListPushBack(s.WithOpName("tl1w3"), tl1.handle, allow2); Output tl1r2 = ops::TensorListPopBack(s.WithOpName("tl1r2"), tl1w3.output_handle, shape, DT_FLOAT) .tensor; @@ -804,11 +808,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushPop) { const char* type_key = "element_dtype"; EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), 
DT_HALF); EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_FLOAT); @@ -826,19 +830,19 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); tensorflow::Input shape = {32}; Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); - auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), wht1, shape); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); + auto tl1 = ops::TensorListFromTensor(s.WithOpName("tl1"), allow1, shape); Output tl1r1 = ops::TensorListStack(s.WithOpName("tl1r1"), tl1.output_handle, shape, DT_FLOAT) .tensor; Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl1r1); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1); - Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1); + Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2); - // This tests that a white-painted object node (tl2) will force an unpainted - // client node (tl2w1) to be painted white as well. (Without the force, tl2w1 + // This tests that an allow-painted object node (tl2) will force an unpainted + // client node (tl2w1) to be painted allow as well. (Without the force, tl2w1 // would remain unpainted, producing an invalid graph). - auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), wht1, shape); + auto tl2 = ops::TensorListFromTensor(s.WithOpName("tl2"), allow1, shape); auto tl2w1 = ops::TensorListPushBack(s.WithOpName("tl2w1"), tl2.output_handle, input); @@ -856,11 +860,11 @@ TEST_F(AutoMixedPrecisionTest, TensorListFromTensor) { GraphView output_view(&output); EXPECT_EQ(output.node_size(), item.graph.node_size() + 2); const char* type_key = "element_dtype"; - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF); @@ -878,12 +882,13 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); auto tl2 = ops::EmptyTensorList(s.WithOpName("tl2"), {32, 32}, 8, DT_FLOAT); Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); Output tl1_tl2 = ops::Stack(s.WithOpName("tl1_tl2"), {tl1.handle, tl2.handle}); - Output wht1_wht1 = ops::Stack(s.WithOpName("wht1_wht1"), {wht1, wht1}); - auto tl12w1 = - ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2, wht1_wht1); + Output allow1_allow1 = + ops::Stack(s.WithOpName("allow1_allow1"), {allow1, allow1}); + auto tl12w1 = ops::TensorListPushBackBatch(s.WithOpName("tl12w1"), tl1_tl2, +
allow1_allow1); OutputList tl12w1_outputs = ops::Split(s.WithOpName("tl12w1_outputs"), 0, tl12w1.output_handles, 2) .output; @@ -898,8 +903,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { ops::TensorListPopBack(s.WithOpName("tl3r1"), tl3, shape, DT_FLOAT) .tensor; Output gry1 = ops::Tanh(s.WithOpName("gry1"), tl3r1); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1); - Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1); + Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2); GrapplerItem item; item.fetch = {"fetch1"}; @@ -915,8 +920,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListPushBackBatchAndConcatLists) { GraphView output_view(&output); EXPECT_EQ(output.node_size(), item.graph.node_size() + 2); const char* type_key = "element_dtype"; - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl1")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF); @@ -961,8 +966,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) { TF_CHECK_OK(s.graph()->AddFunctionLibrary(function_lib)); tensorflow::Input shape = {32, 32}; Output input = ops::Const(s.WithOpName("input"), 1.f / 32, {32, 32}); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); - Output gry1 = ops::Tanh(s.WithOpName("gry1"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); + Output gry1 = ops::Tanh(s.WithOpName("gry1"), allow1); auto tl1 = ops::EmptyTensorList(s.WithOpName("tl1"), {32, 32}, 8, DT_FLOAT); auto tl1w1 = ops::TensorListPushBack(s.WithOpName("tl1w1"), tl1.handle, gry1); auto _gry1 = tensorflow::ops::AsNodeOut(s, gry1); @@ -981,8 +986,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) { Output tl2r1 = ops::TensorListPopBack(s.WithOpName("tl2r1"), tl2w1.output_handle, shape, DT_FLOAT) .tensor; - Output wht2 = ops::MatMul(s.WithOpName("wht2"), tl1r1, tl2r1); - Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht2); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), tl1r1, tl2r1); + Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow2); GrapplerItem item; item.fetch = {"fetch1"}; @@ -997,8 +1002,8 @@ TEST_F(AutoMixedPrecisionTest, TensorListThroughFunction) { GraphView output_view(&output); const char* type_key = "element_dtype"; - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_HALF); EXPECT_EQ(output_view.GetNode("tl2w1")->attr().at(type_key).type(), DT_HALF); @@ -1031,8 +1036,8 @@ int GetCudaVersion(const Cluster& cluster) { TEST_F(AutoMixedPrecisionTest, BatchMatMul) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f / 33, {64, 32, 32}); - Output wht1 = 
ops::BatchMatMul(s.WithOpName("wht1"), input, input); - Output fetch1 = ops::Identity(s.WithOpName("fetch1"), wht1); + Output allow1 = ops::BatchMatMul(s.WithOpName("allow1"), input, input); + Output fetch1 = ops::Identity(s.WithOpName("fetch1"), allow1); GrapplerItem item; item.fetch = {"fetch1"}; @@ -1049,10 +1054,10 @@ TEST_F(AutoMixedPrecisionTest, BatchMatMul) { EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); if (GetCudaVersion(*virtual_cluster_.get()) >= 9010) { EXPECT_EQ(output.node_size(), item.graph.node_size() + 2); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_HALF); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_HALF); } else { EXPECT_EQ(output.node_size(), item.graph.node_size()); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_FLOAT); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_FLOAT); } auto tensors = EvaluateNodes(output, item.fetch); @@ -1187,8 +1192,8 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) { tensorflow::Scope s = tensorflow::Scope::NewRootScope(); Output input = ops::Const(s.WithOpName("input"), 1.f, {32, 32}); Output cst1 = ops::Cast(s.WithOpName("cst1"), input, DT_BFLOAT16); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), cst1, cst1); - Output clr1 = ops::Relu(s.WithOpName("clr1"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), cst1, cst1); + Output clr1 = ops::Relu(s.WithOpName("clr1"), allow1); Output cst2 = ops::Cast(s.WithOpName("cst2"), clr1, DT_FLOAT); Output clr2 = ops::Relu(s.WithOpName("clr2"), cst2); Output fetch = ops::Identity(s.WithOpName("fetch"), clr2); @@ -1207,7 +1212,7 @@ TEST_F(AutoMixedPrecisionMklTest, AlreadyBf16) { GraphView output_view(&output); EXPECT_EQ(output_view.GetNode("input")->attr().at("dtype").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("cst1")->attr().at("DstT").type(), DT_BFLOAT16); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("cst2")->attr().at("SrcT").type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("cst2")->attr().at("DstT").type(), DT_FLOAT); @@ -1228,8 +1233,8 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) { Output clr1 = ops::Relu(s.WithOpName("clr1"), blk1); Output gry1 = ops::Sqrt(s.WithOpName("gry1"), clr1); Output clr2 = ops::Relu(s.WithOpName("clr2"), gry1); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), clr2, clr2); - Output clr3 = ops::Relu(s.WithOpName("clr3"), wht1); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), clr2, clr2); + Output clr3 = ops::Relu(s.WithOpName("clr3"), allow1); Output blk2 = ops::Log(s.WithOpName("blk2"), clr3); Output clr4 = ops::Relu(s.WithOpName("clr4"), blk2); Output blk3 = ops::SparseMatMul(s.WithOpName("blk3"), clr4, clr4); @@ -1254,7 +1259,7 @@ TEST_F(AutoMixedPrecisionMklTest, Simple) { EXPECT_EQ(output_view.GetNode("clr1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_FLOAT); EXPECT_EQ(output_view.GetNode("clr2")->attr().at("T").type(), DT_BFLOAT16); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("clr3")->attr().at("T").type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("blk2")->attr().at("T").type(), DT_FLOAT); 
EXPECT_EQ(output_view.GetNode("clr4")->attr().at("T").type(), DT_FLOAT); @@ -1280,9 +1285,9 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) { Output idx3 = ops::Const(s.WithOpName("idx3"), 3); auto tl1w1 = ops::TensorListSetItem(s.WithOpName("tl1w1"), tl1.handle, idx1, input); - Output wht1 = ops::MatMul(s.WithOpName("wht1"), input, input); + Output allow1 = ops::MatMul(s.WithOpName("allow1"), input, input); auto tl1w2 = - ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, wht1); + ops::TensorListSetItem(s.WithOpName("tl1w2"), tl1.handle, idx2, allow1); // Ensure that TensorListResize doesn't cause any problems. Output tl1rs = ops::TensorListResize(s.WithOpName("tl1rs"), tl1w2.output_handle, 6); @@ -1290,9 +1295,9 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) { shape, DT_FLOAT) .item; Output gry1 = ops::Mul(s.WithOpName("gry1"), tl1r1, tl1r1); - Output wht2 = ops::MatMul(s.WithOpName("wht2"), gry1, gry1); + Output allow2 = ops::MatMul(s.WithOpName("allow2"), gry1, gry1); auto tl1w3 = - ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, wht2); + ops::TensorListSetItem(s.WithOpName("tl1w3"), tl1.handle, idx3, allow2); Output tl1r2 = ops::TensorListGetItem(s.WithOpName("tl1r2"), tl1w3.output_handle, idx3, shape, DT_FLOAT) @@ -1325,13 +1330,13 @@ TEST_F(AutoMixedPrecisionMklTest, TensorListSetGet) { DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("tl1w1")->attr().at(type_key).type(), DT_BFLOAT16); - EXPECT_EQ(output_view.GetNode("wht1")->attr().at("T").type(), DT_BFLOAT16); + EXPECT_EQ(output_view.GetNode("allow1")->attr().at("T").type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("tl1w2")->attr().at(type_key).type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("tl1r1")->attr().at(type_key).type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("gry1")->attr().at("T").type(), DT_BFLOAT16); - EXPECT_EQ(output_view.GetNode("wht2")->attr().at("T").type(), DT_BFLOAT16); + EXPECT_EQ(output_view.GetNode("allow2")->attr().at("T").type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("tl1w3")->attr().at(type_key).type(), DT_BFLOAT16); EXPECT_EQ(output_view.GetNode("tl2")->attr().at(type_key).type(), DT_FLOAT); diff --git a/tensorflow/core/grappler/optimizers/constant_folding.cc b/tensorflow/core/grappler/optimizers/constant_folding.cc index 14effd929de..bcb8ad37d6c 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.cc +++ b/tensorflow/core/grappler/optimizers/constant_folding.cc @@ -1020,9 +1020,9 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node, return false; } - // Skips nodes that must be preserved except whitelisted nodes. + // Skips nodes that must be preserved except allowlisted nodes. if (nodes_to_preserve_.find(node.name()) != nodes_to_preserve_.end() && - nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) { + nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) { return false; } @@ -1082,13 +1082,13 @@ bool ConstantFolding::MaybeFoldable(const NodeDef& node, } } - // Don't fold nodes that have no outgoing edges except whitelisted nodes. + // Don't fold nodes that have no outgoing edges except allowlisted nodes. // Such nodes could be introduced by an earlier constant folding pass and are // preserved in case users want to fetch their values; re-processing them // would lead to an error of adding a duplicated node to graph. 
const auto& outputs = node_map_->GetOutputs(node.name()); if (outputs.empty() && - nodes_whitelist_.find(node.name()) == nodes_whitelist_.end()) { + nodes_allowlist_.find(node.name()) == nodes_allowlist_.end()) { return false; } return true; @@ -3874,7 +3874,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, GraphDef* optimized_graph) { graph_ = &item->graph; node_map_.reset(new NodeMap(graph_)); - nodes_whitelist_.clear(); + nodes_allowlist_.clear(); // Fold fetch nodes iff it has a single fanout. Note that if a fetch node // has a single fanout, it would be rewritten as a constant with the same // node name, and therefore users are still able to fetch it. This is not @@ -3885,7 +3885,7 @@ Status ConstantFolding::RunOptimizationPass(Cluster* cluster, for (const auto& fetch : item->fetch) { const NodeDef* fetch_node = node_map_->GetNode(fetch); if (fetch_node && NumOutputs(*fetch_node, graph_) == 1) { - nodes_whitelist_.insert(fetch_node->name()); + nodes_allowlist_.insert(fetch_node->name()); } } diff --git a/tensorflow/core/grappler/optimizers/constant_folding.h b/tensorflow/core/grappler/optimizers/constant_folding.h index 0f4d91e4315..4e3deb40d15 100644 --- a/tensorflow/core/grappler/optimizers/constant_folding.h +++ b/tensorflow/core/grappler/optimizers/constant_folding.h @@ -328,7 +328,7 @@ class ConstantFolding : public GraphOptimizer { std::unique_ptr node_map_; std::unordered_set nodes_to_preserve_; // TODO(rmlarsen): Could these be keyed on absl::string_view? - absl::flat_hash_set nodes_whitelist_; + absl::flat_hash_set nodes_allowlist_; absl::flat_hash_set feed_nodes_; absl::flat_hash_map maybe_foldable_nodes_; bool has_fetch_; diff --git a/tensorflow/core/kernels/data/captured_function.cc b/tensorflow/core/kernels/data/captured_function.cc index 92409af8f61..aac07eebfa1 100644 --- a/tensorflow/core/kernels/data/captured_function.cc +++ b/tensorflow/core/kernels/data/captured_function.cc @@ -232,16 +232,16 @@ Status IsFunctionStateful(const FunctionLibraryDefinition& library, return Status::OK(); } -// Returns whether an op has been whitelisted as stateless. Uses a heuristic to -// whitelist source dataset ops which have been marked stateful due to +// Returns whether an op has been allowlisted as stateless. Uses a heuristic to +// allowlist source dataset ops which have been marked stateful due to // b/65524810. Also looks up the `op_def->name` in the global -// `WhitelistedStatefulOpRegistry`. -bool IsOpWhitelisted(const OpDef* op_def) { +// `AllowlistedStatefulOpRegistry`. +bool IsOpAllowlisted(const OpDef* op_def) { return (op_def->output_arg_size() == 1 && op_def->output_arg(0).type() == DT_VARIANT && (absl::EndsWith(op_def->name(), "Dataset") || absl::EndsWith(op_def->name(), "DatasetV2"))) || - WhitelistedStatefulOpRegistry::Global()->Contains(op_def->name()); + AllowlistedStatefulOpRegistry::Global()->Contains(op_def->name()); } Status LookupFunction(const FunctionLibraryDefinition& lib_def, @@ -389,7 +389,7 @@ Status IsNodeStateful(const FunctionLibraryDefinition& library, // TODO(jsimsa): Fix C++ unit tests so that we do not have to ignore // `LookUpOpDef` errors here. 
if (!OpRegistry::Global()->LookUpOpDef(node.op(), &op_def).ok() || - IsOpWhitelisted(op_def) || !op_def->is_stateful() || + IsOpAllowlisted(op_def) || !op_def->is_stateful() || op_def->name() == "Assert") { return Status::OK(); } diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index 843ef2fb7e1..b44b7965040 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -41478,13 +41478,13 @@ func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_ke // DatasetToGraphAttr is an optional argument to DatasetToGraph. type DatasetToGraphAttr func(optionalAttr) -// DatasetToGraphStatefulWhitelist sets the optional stateful_whitelist attribute to value. +// DatasetToGraphStatefulAllowlist sets the optional stateful_allowlist attribute to value. // If not specified, defaults to <> // // REQUIRES: len(value) >= 0 -func DatasetToGraphStatefulWhitelist(value []string) DatasetToGraphAttr { +func DatasetToGraphStatefulAllowlist(value []string) DatasetToGraphAttr { return func(m optionalAttr) { - m["stateful_whitelist"] = value + m["stateful_allowlist"] = value } } diff --git a/tensorflow/lite/delegates/flex/BUILD b/tensorflow/lite/delegates/flex/BUILD index 99bcf05ab4a..2ee6109521b 100644 --- a/tensorflow/lite/delegates/flex/BUILD +++ b/tensorflow/lite/delegates/flex/BUILD @@ -233,10 +233,10 @@ tf_cc_test( cc_library( name = "whitelisted_flex_ops_lib", srcs = [ - "whitelisted_flex_ops.cc", + "allowlisted_flex_ops.cc", ], hdrs = [ - "whitelisted_flex_ops.h", + "allowlisted_flex_ops.h", ], ) diff --git a/tensorflow/lite/delegates/flex/whitelisted_flex_ops.cc b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc similarity index 98% rename from tensorflow/lite/delegates/flex/whitelisted_flex_ops.cc rename to tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc index a4e11e54905..885601e5333 100644 --- a/tensorflow/lite/delegates/flex/whitelisted_flex_ops.cc +++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.cc @@ -12,15 +12,15 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" - #include +#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" + namespace tflite { namespace flex { -bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) { - static const std::set* whitelisted_flex_ops = +bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name) { + static const std::set* allowlisted_flex_ops = new std::set({ // go/keep-sorted start "Abort", @@ -538,8 +538,8 @@ bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name) { "_Send", // go/keep-sorted end }); - return whitelisted_flex_ops->find(tensorflow_op_name) != - whitelisted_flex_ops->end(); + return allowlisted_flex_ops->find(tensorflow_op_name) != + allowlisted_flex_ops->end(); // Prevent lint error about this function being too long. This function // is a set of ops, and making it shorter won't help readbility. 
// NOLINTNEXTLINE diff --git a/tensorflow/lite/delegates/flex/whitelisted_flex_ops.h b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.h similarity index 65% rename from tensorflow/lite/delegates/flex/whitelisted_flex_ops.h rename to tensorflow/lite/delegates/flex/allowlisted_flex_ops.h index 189a6940536..46b7068de25 100644 --- a/tensorflow/lite/delegates/flex/whitelisted_flex_ops.h +++ b/tensorflow/lite/delegates/flex/allowlisted_flex_ops.h @@ -12,24 +12,24 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ -#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_ -#define TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_ +#ifndef TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_ +#define TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_ #include namespace tflite { namespace flex { -// Whether the given op has been statically whitelisted for flex export. +// Whether the given op has been statically allowlisted for flex export. // -// This static whitelist is formed by the intersection of ops supported by +// This static allowlist is formed by the intersection of ops supported by // TensorFlowMobile on both iOS and Android. As the converter is likely running // on a host that has the full suite of TensorFlow ops available, we use this -// static whitelist to ensure compatibility when deploying to a mobile device. -// TODO(b/118389105): Automate generation of the whitelisted flex ops. -bool IsWhitelistedFlexOp(const std::string& tensorflow_op_name); +// static allowlist to ensure compatibility when deploying to a mobile device. +// TODO(b/118389105): Automate generation of the allowlisted flex ops. +bool IsAllowlistedFlexOp(const std::string& tensorflow_op_name); } // namespace flex } // namespace tflite -#endif // TENSORFLOW_LITE_DELEGATES_FLEX_WHITELISTED_FLEX_OPS_H_ +#endif // TENSORFLOW_LITE_DELEGATES_FLEX_ALLOWLISTED_FLEX_OPS_H_ diff --git a/tensorflow/lite/delegates/hexagon/utils.cc b/tensorflow/lite/delegates/hexagon/utils.cc index 223d4a8a826..14d651a9d7d 100644 --- a/tensorflow/lite/delegates/hexagon/utils.cc +++ b/tensorflow/lite/delegates/hexagon/utils.cc @@ -70,7 +70,7 @@ TfLiteStatus Get4DShape(unsigned int* batch_size, unsigned int* height_size, return kTfLiteOk; } -// We maintain an op-version whitelist here to ensure we don't accept unintended +// We maintain an op-version allowlist here to ensure we don't accept unintended // ops. 
bool CheckOpVersion(const TfLiteRegistration* registration) { switch (registration->builtin_code) { diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc index aff1ca23ee5..14c52f52bd6 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_list.cc @@ -18,7 +18,7 @@ namespace tflite { const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig = R"( -## Every Test can be whitelisted or blacklisted using a regexp on its test_id +## Every Test can be allowlisted or blacklisted using a regexp on its test_id ## Test_id # @@ -28,7 +28,7 @@ const constexpr char* NnapiAccelerationTestParams::kAccelerationTestConfig = # the ordinal is the position in the list of parameters generated by the # cardinal product of all the different parameter sets -# Blacklist/Whitelist +# Blacklist/Allowlist # To blacklist an element simply add - before the test_id regex ## Rules evaluation diff --git a/tensorflow/lite/delegates/nnapi/acceleration_test_util.h b/tensorflow/lite/delegates/nnapi/acceleration_test_util.h index e99225e175f..042cb416e6b 100644 --- a/tensorflow/lite/delegates/nnapi/acceleration_test_util.h +++ b/tensorflow/lite/delegates/nnapi/acceleration_test_util.h @@ -21,7 +21,7 @@ limitations under the License. namespace tflite { -// NNAPI specific configuration for the validation whitelist. +// NNAPI specific configuration for the validation allowlist. class NnapiAccelerationTestParams { public: // Content in nnapi_acceleration_test_list.cc. diff --git a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc index cb95d7bd248..ce55d671b5d 100644 --- a/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc +++ b/tensorflow/lite/delegates/nnapi/nnapi_delegate.cc @@ -4526,7 +4526,7 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context, } else { // If no accelerator is specified, only use NNAPI if an accelerator is // available. Any available accelerator will make the device_count larger - // than 1. More sophisticated check and whitelisting can be added later. + // than 1. More sophisticated check and allowlisting can be added later. uint32_t device_count = 0; RETURN_TFLITE_ERROR_IF_NN_ERROR( context, nnapi->ANeuralNetworks_getDeviceCount(&device_count), diff --git a/tensorflow/lite/experimental/acceleration/README.md b/tensorflow/lite/experimental/acceleration/README.md index c3209fe99e9..bd07b4f0b2b 100644 --- a/tensorflow/lite/experimental/acceleration/README.md +++ b/tensorflow/lite/experimental/acceleration/README.md @@ -1,4 +1,4 @@ -# Accelerator whitelisting +# Accelerator allowlisting Experimental library and tools for determining whether an accelerator engine works well on a given device, and for a given model. @@ -6,7 +6,7 @@ works well on a given device, and for a given model. ## Platform-agnostic, Android-first Android-focused, since the much smaller set of configurations on iOS means there -is much less need for whitelisting on iOS. +is much less need for allowlisting on iOS. 
## Not just for TfLite diff --git a/tensorflow/lite/experimental/acceleration/configuration/configuration.proto b/tensorflow/lite/experimental/acceleration/configuration/configuration.proto index 44d462da073..3091eec6d46 100644 --- a/tensorflow/lite/experimental/acceleration/configuration/configuration.proto +++ b/tensorflow/lite/experimental/acceleration/configuration/configuration.proto @@ -32,7 +32,7 @@ package tflite.proto; // compatibility list entries have been developed for and what settings are used // for NNAPI. enum ExecutionPreference { - // Match any selected preference. Whitelist (semantically - value is same as + // Match any selected preference. Allowlist (semantically - value is same as // on input). ANY = 0; // Match low latency preference. Both compatibility list and input. diff --git a/tensorflow/lite/g3doc/guide/ops_select.md b/tensorflow/lite/g3doc/guide/ops_select.md index dac0cadcb9c..5aa3e96cae2 100644 --- a/tensorflow/lite/g3doc/guide/ops_select.md +++ b/tensorflow/lite/g3doc/guide/ops_select.md @@ -39,8 +39,8 @@ for `target_spec.supported_ops`: * `TFLITE_BUILTINS` - Converts models using TensorFlow Lite builtin ops. * `SELECT_TF_OPS` - Converts models using TensorFlow ops. The exact subset of - supported ops can be found in the whitelist at - `lite/delegates/flex/whitelisted_flex_ops.cc`. + supported ops can be found in the allowlist at + `lite/delegates/flex/allowlisted_flex_ops.cc`. Note: `target_spec.supported_ops` was previously `target_ops` in the Python API. diff --git a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/CompatibilityListTest.java b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/CompatibilityListTest.java index 5693c7a74d7..b04189cbfcf 100644 --- a/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/CompatibilityListTest.java +++ b/tensorflow/lite/java/src/test/java/org/tensorflow/lite/gpu/CompatibilityListTest.java @@ -27,8 +27,8 @@ public final class CompatibilityListTest { @Test public void testBasic() throws Exception { - try (CompatibilityList whitelist = new CompatibilityList()) { - assertThat(whitelist.isDelegateSupportedOnThisDevice()).isTrue(); + try (CompatibilityList allowlist = new CompatibilityList()) { + assertThat(allowlist.isDelegateSupportedOnThisDevice()).isTrue(); } } } diff --git a/tensorflow/lite/kernels/acceleration_test_util.h b/tensorflow/lite/kernels/acceleration_test_util.h index a6a88d5f131..78e4d01a44d 100644 --- a/tensorflow/lite/kernels/acceleration_test_util.h +++ b/tensorflow/lite/kernels/acceleration_test_util.h @@ -20,7 +20,7 @@ limitations under the License. namespace tflite { // Returns the test id to use to retrieve the acceleration configuration -// in the acceleration whitelist. +// in the acceleration allowlist. 
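The `ops_select.md` hunk above documents `target_spec.supported_ops`, the converter option backed by the renamed flex-op allowlist. A minimal Python sketch of that documented flow (illustrative only, not part of this patch; the SavedModel and output paths are placeholders):

```python
import tensorflow as tf

# Convert a SavedModel, keeping ops that have no TFLite builtin kernel as
# TensorFlow (flex) ops, which the exporter checks against the allowlist in
# lite/delegates/flex/allowlisted_flex_ops.cc.
converter = tf.lite.TFLiteConverter.from_saved_model("/tmp/my_saved_model")
converter.target_spec.supported_ops = [
    tf.lite.OpsSet.TFLITE_BUILTINS,  # ops converted to TFLite builtins
    tf.lite.OpsSet.SELECT_TF_OPS,    # allowlisted TF ops, run via the flex delegate
]
tflite_model = converter.convert()
with open("/tmp/model_with_flex.tflite", "wb") as f:
    f.write(tflite_model)
```

With only `TFLITE_BUILTINS`, conversion fails for any op outside the builtin set; adding `SELECT_TF_OPS` keeps such ops as flex ops, subject to the allowlist.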
std::string GetCurrentTestId(); } // namespace tflite diff --git a/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc b/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc index 82d21fd9332..6d6b7a722b8 100644 --- a/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc +++ b/tensorflow/lite/kernels/acceleration_test_util_internal_test.cc @@ -51,14 +51,14 @@ struct SimpleConfig { class ReadAccelerationConfigTest : public ::testing::Test { public: - std::unordered_map whitelist_; + std::unordered_map allowlist_; std::unordered_map blacklist_; std::function consumer_ = [this](std::string key, std::string value, bool is_blacklist) { if (is_blacklist) { blacklist_[key] = {value}; } else { - whitelist_[key] = {value}; + allowlist_[key] = {value}; } }; }; @@ -66,21 +66,21 @@ class ReadAccelerationConfigTest : public ::testing::Test { TEST_F(ReadAccelerationConfigTest, ReadsAKeyOnlyLine) { ReadAccelerationConfig("key", consumer_); - EXPECT_THAT(whitelist_.find("key"), Not(Eq(whitelist_.end()))); + EXPECT_THAT(allowlist_.find("key"), Not(Eq(allowlist_.end()))); EXPECT_TRUE(blacklist_.empty()); } TEST_F(ReadAccelerationConfigTest, ReadsABlacklistKeyOnlyLine) { ReadAccelerationConfig("-key", consumer_); - EXPECT_THAT(blacklist_.find("key"), Not(Eq(whitelist_.end()))); - EXPECT_TRUE(whitelist_.empty()); + EXPECT_THAT(blacklist_.find("key"), Not(Eq(allowlist_.end()))); + EXPECT_TRUE(allowlist_.empty()); } TEST_F(ReadAccelerationConfigTest, ReadsAKeyValueLine) { ReadAccelerationConfig("key,value", consumer_); - EXPECT_THAT(whitelist_["key"].value, Eq("value")); + EXPECT_THAT(allowlist_["key"].value, Eq("value")); EXPECT_TRUE(blacklist_.empty()); } @@ -88,13 +88,13 @@ TEST_F(ReadAccelerationConfigTest, ReadsABlackListKeyValueLine) { ReadAccelerationConfig("-key,value", consumer_); EXPECT_THAT(blacklist_["key"].value, Eq("value")); - EXPECT_TRUE(whitelist_.empty()); + EXPECT_TRUE(allowlist_.empty()); } TEST_F(ReadAccelerationConfigTest, KeysAreLeftTrimmed) { ReadAccelerationConfig(" key,value", consumer_); - EXPECT_THAT(whitelist_["key"].value, Eq("value")); + EXPECT_THAT(allowlist_["key"].value, Eq("value")); EXPECT_TRUE(blacklist_.empty()); } @@ -102,58 +102,58 @@ TEST_F(ReadAccelerationConfigTest, BlKeysAreLeftTrimmed) { ReadAccelerationConfig(" -key,value", consumer_); EXPECT_THAT(blacklist_["key"].value, Eq("value")); - EXPECT_TRUE(whitelist_.empty()); + EXPECT_TRUE(allowlist_.empty()); } TEST_F(ReadAccelerationConfigTest, IgnoresCommentedLines) { ReadAccelerationConfig("#key,value", consumer_); - EXPECT_TRUE(whitelist_.empty()); + EXPECT_TRUE(allowlist_.empty()); EXPECT_TRUE(blacklist_.empty()); } TEST_F(ReadAccelerationConfigTest, CommentCanHaveTrailingBlanks) { ReadAccelerationConfig(" #key,value", consumer_); - EXPECT_TRUE(whitelist_.empty()); + EXPECT_TRUE(allowlist_.empty()); EXPECT_TRUE(blacklist_.empty()); } TEST_F(ReadAccelerationConfigTest, CommentsAreOnlyForTheFullLine) { ReadAccelerationConfig("key,value #comment", consumer_); - EXPECT_THAT(whitelist_["key"].value, Eq("value #comment")); + EXPECT_THAT(allowlist_["key"].value, Eq("value #comment")); } TEST_F(ReadAccelerationConfigTest, IgnoresEmptyLines) { ReadAccelerationConfig("", consumer_); - EXPECT_TRUE(whitelist_.empty()); + EXPECT_TRUE(allowlist_.empty()); EXPECT_TRUE(blacklist_.empty()); } TEST_F(ReadAccelerationConfigTest, ParsesMultipleLines) { ReadAccelerationConfig("key1,value1\nkey2,value2\n-key3,value3", consumer_); - EXPECT_THAT(whitelist_["key1"].value, Eq("value1")); - 
EXPECT_THAT(whitelist_["key2"].value, Eq("value2")); + EXPECT_THAT(allowlist_["key1"].value, Eq("value1")); + EXPECT_THAT(allowlist_["key2"].value, Eq("value2")); EXPECT_THAT(blacklist_["key3"].value, Eq("value3")); } TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithCommentsAndSpaces) { ReadAccelerationConfig("key1,value1\n#comment\n\nkey2,value2", consumer_); - EXPECT_THAT(whitelist_["key1"].value, Eq("value1")); - EXPECT_THAT(whitelist_["key2"].value, Eq("value2")); + EXPECT_THAT(allowlist_["key1"].value, Eq("value1")); + EXPECT_THAT(allowlist_["key2"].value, Eq("value2")); } TEST_F(ReadAccelerationConfigTest, ParsesMultipleLinesWithMissingConfigValues) { ReadAccelerationConfig("key1\nkey2,value2\nkey3\nkey4,value4", consumer_); - EXPECT_THAT(whitelist_["key1"].value, Eq("")); - EXPECT_THAT(whitelist_["key2"].value, Eq("value2")); - EXPECT_THAT(whitelist_["key3"].value, Eq("")); - EXPECT_THAT(whitelist_["key4"].value, Eq("value4")); + EXPECT_THAT(allowlist_["key1"].value, Eq("")); + EXPECT_THAT(allowlist_["key2"].value, Eq("value2")); + EXPECT_THAT(allowlist_["key3"].value, Eq("")); + EXPECT_THAT(allowlist_["key4"].value, Eq("value4")); } TEST(GetAccelerationTestParam, LoadsTestConfig) { diff --git a/tensorflow/lite/micro/tools/make/generate_keil_project.py b/tensorflow/lite/micro/tools/make/generate_keil_project.py index 5a9950cfd96..a022be3a3ab 100644 --- a/tensorflow/lite/micro/tools/make/generate_keil_project.py +++ b/tensorflow/lite/micro/tools/make/generate_keil_project.py @@ -27,7 +27,7 @@ import six def sanitize_xml(unsanitized): - """Uses a whitelist to avoid generating bad XML.""" + """Uses a allowlist to avoid generating bad XML.""" return re.sub(r'[^a-zA-Z0-9+_\-/\\.]', '', six.ensure_str(unsanitized)) diff --git a/tensorflow/lite/toco/tflite/export_test.cc b/tensorflow/lite/toco/tflite/export_test.cc index 16724f7ea46..dd0b1273dca 100644 --- a/tensorflow/lite/toco/tflite/export_test.cc +++ b/tensorflow/lite/toco/tflite/export_test.cc @@ -794,7 +794,7 @@ TEST(OperatorKeyTest, TestFlexWithUnsupportedOp) { EXPECT_EQ(key.type(), ::tflite::BuiltinOperator_CUSTOM); EXPECT_EQ(key.custom_code(), "HashTableV2"); EXPECT_EQ(key.version(), 1); - // While HashTableV2 is excluded from the whitelisted flex op list, eventually + // While HashTableV2 is excluded from the allowlisted flex op list, eventually // it won't be, and the following expectations will need to change as the op // is explicitly blacklisted due to lack of asset support. EXPECT_FALSE(key.is_flex_op()); diff --git a/tensorflow/lite/toco/tflite/operator.cc b/tensorflow/lite/toco/tflite/operator.cc index be539cf6054..bc12d49a115 100644 --- a/tensorflow/lite/toco/tflite/operator.cc +++ b/tensorflow/lite/toco/tflite/operator.cc @@ -24,7 +24,7 @@ limitations under the License. // TODO(ycling): Consider refactoring to extract the LSTM definition out of // graph_transformation module. -#include "tensorflow/lite/delegates/flex/whitelisted_flex_ops.h" +#include "tensorflow/lite/delegates/flex/allowlisted_flex_ops.h" #include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/toco/graph_transformations/lstm_utils.h" #include "tensorflow/lite/toco/model.h" @@ -2116,7 +2116,7 @@ bool ShouldExportAsFlexOp(bool enable_select_tf_ops, return false; } // Check if we can find the `OpDef` for the TensorFlow op. If we can find - // it and it has been whitelisted, export the op as an Flex op. Otherwise, + // it and it has been allowlisted, export the op as an Flex op. Otherwise, // export it as a regular custom op. 
const tensorflow::OpDef* op_def = nullptr; if (!tensorflow::OpRegistry::Global() @@ -2125,9 +2125,9 @@ bool ShouldExportAsFlexOp(bool enable_select_tf_ops, return false; } - if (!::tflite::flex::IsWhitelistedFlexOp(tensorflow_op_name)) { + if (!::tflite::flex::IsAllowlistedFlexOp(tensorflow_op_name)) { LOG(WARNING) << "Op " << tensorflow_op_name - << " is a valid TensorFlow op but has not been whitelisted for" + << " is a valid TensorFlow op but has not been allowlisted for" " the TensorFlow Lite flex op set."; return false; } diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/README.md b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/README.md index a5baff10a28..590c15cc133 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/README.md +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/README.md @@ -156,7 +156,7 @@ To do so, we utilize the `preprocess_coco_minival` Python binary as follows: bazel run //tensorflow/lite/tools/evaluation/tasks/coco_object_detection:preprocess_coco_minival -- \ --images_folder=/path/to/val2014 \ --instances_file=/path/to/instances_val2014.json \ - --whitelist_file=/path/to/minival_whitelist.txt \ + --allowlist_file=/path/to/minival_allowlist.txt \ --output_folder=/path/to/output/folder ``` diff --git a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py index ab086538a04..de9ac65e457 100644 --- a/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py +++ b/tensorflow/lite/tools/evaluation/tasks/coco_object_detection/preprocess_coco_minival.py @@ -16,13 +16,13 @@ The 2014 validation images & annotations can be downloaded from: http://cocodataset.org/#download -The minival image ID whitelist, a subset of the 2014 validation set, can be +The minival image ID allowlist, a subset of the 2014 validation set, can be found here: https://github.com/tensorflow/models/blob/master/research/object_detection/data/mscoco_minival_ids.txt. This script takes in the original images folder, instances JSON file and -image ID whitelist and produces the following in the specified output folder: -A subfolder for whitelisted images (images/), and a file (ground_truth.pbtxt) +image ID allowlist and produces the following in the specified output folder: +A subfolder for allowlisted images (images/), and a file (ground_truth.pbtxt) containing an instance of tflite::evaluation::ObjectDetectionGroundTruth. """ @@ -40,17 +40,17 @@ from tensorflow.lite.tools.evaluation.proto import evaluation_stages_pb2 def _get_ground_truth_detections(instances_file, - whitelist_file=None, + allowlist_file=None, num_images=None): - """Processes the annotations JSON file and returns ground truth data corresponding to whitelisted image IDs. + """Processes the annotations JSON file and returns ground truth data corresponding to allowlisted image IDs. Args: instances_file: COCO instances JSON file, usually named as instances_val20xx.json. - whitelist_file: File containing COCO minival image IDs to whitelist for + allowlist_file: File containing COCO minival image IDs to allowlist for evaluation, one per line. - num_images: Number of whitelisted images to pre-process. First num_images - are chosen based on sorted list of filenames. If None, all whitelisted + num_images: Number of allowlisted images to pre-process. First num_images + are chosen based on sorted list of filenames. 
If None, all allowlisted files are preprocessed. Returns: @@ -70,17 +70,17 @@ def _get_ground_truth_detections(instances_file, image_data = collections.OrderedDict() all_file_names = [] - # Read whitelist. - if whitelist_file is not None: - with open(whitelist_file, 'r') as whitelist: - image_id_whitelist = set([int(x) for x in whitelist.readlines()]) + # Read allowlist. + if allowlist_file is not None: + with open(allowlist_file, 'r') as allowlist: + image_id_allowlist = set([int(x) for x in allowlist.readlines()]) else: - image_id_whitelist = [image['id'] for image in data_dict['images']] + image_id_allowlist = [image['id'] for image in data_dict['images']] # Get image names and dimensions. for image_dict in data_dict['images']: image_id = image_dict['id'] - if image_id not in image_id_whitelist: + if image_id not in image_id_allowlist: continue image_data_dict = {} image_data_dict['id'] = image_dict['id'] @@ -99,7 +99,7 @@ def _get_ground_truth_detections(instances_file, # Get detected object annotations per image. for annotation_dict in data_dict['annotations']: image_id = annotation_dict['image_id'] - if image_id not in image_id_whitelist: + if image_id not in image_id_allowlist: continue if image_id not in image_data: continue @@ -133,7 +133,7 @@ def _dump_data(ground_truth_detections, images_folder_path, output_folder_path): """Dumps images & data from ground-truth objects into output_folder_path. The following are created in output_folder_path: - images/: sub-folder for whitelisted validation images. + images/: sub-folder for allowlisted validation images. ground_truth.pb: A binary proto file containing all ground-truth object-sets. @@ -193,14 +193,14 @@ def _parse_args(): help='Full path of the input JSON file, like instances_val20xx.json.', required=True) parser.add_argument( - '--whitelist_file', + '--allowlist_file', type=str, help='File with COCO image ids to preprocess, one on each line.', required=False) parser.add_argument( '--num_images', type=int, - help='Number of whitelisted images to preprocess into the output folder.', + help='Number of allowlisted images to preprocess into the output folder.', required=False) parser.add_argument( '--output_folder', @@ -213,6 +213,6 @@ def _parse_args(): if __name__ == '__main__': args = _parse_args() ground_truths = _get_ground_truth_detections(args.instances_file, - args.whitelist_file, + args.allowlist_file, args.num_images) _dump_data(ground_truths, args.images_folder, args.output_folder) diff --git a/tensorflow/lite/tools/optimize/quantize_model.h b/tensorflow/lite/tools/optimize/quantize_model.h index 29f581d2b35..18cd26d3585 100644 --- a/tensorflow/lite/tools/optimize/quantize_model.h +++ b/tensorflow/lite/tools/optimize/quantize_model.h @@ -55,7 +55,7 @@ TfLiteStatus QuantizeModel(flatbuffers::FlatBufferBuilder* builder, const TensorType& output_type, bool allow_float, ErrorReporter* error_reporter); -// Same as above, but enables only quantizing a whitelist of operations, +// Same as above, but enables only quantizing an allowlist of operations, // specified by their operator output name. // // Note: This is a private API, subject to change. diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index 9b42742058c..d68119644c1 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -158,6 +158,6 @@ _exported_dunders = set([ '__monolithic_build__', ]) -# Expose symbols minus dunders, unless they are whitelisted above. 
+# Expose symbols minus dunders, unless they are allowlisted above. # This is necessary to export our dunders. __all__ = [s for s in dir() if s in _exported_dunders or not s.startswith('_')] diff --git a/tensorflow/python/autograph/converters/call_trees.py b/tensorflow/python/autograph/converters/call_trees.py index 505925650d1..33b5966058b 100644 --- a/tensorflow/python/autograph/converters/call_trees.py +++ b/tensorflow/python/autograph/converters/call_trees.py @@ -177,7 +177,7 @@ class CallTreeTransformer(converter.Base): # Calls to pdb.set_trace or ipdb.set_trace are never converted. We don't use # the normal mechanisms to bypass these literals because they are sensitive # to the frame they are being called from. - # TODO(mdan): Generalize this to a "static whitelist" config. + # TODO(mdan): Generalize this to a "static allowlist" config. if full_name in ('pdb.set_trace', 'ipdb.set_trace', 'breakpoint'): global set_trace_warned if not set_trace_warned: diff --git a/tensorflow/python/autograph/core/converter_testing.py b/tensorflow/python/autograph/core/converter_testing.py index 4301cf898bf..9f2604dec94 100644 --- a/tensorflow/python/autograph/core/converter_testing.py +++ b/tensorflow/python/autograph/core/converter_testing.py @@ -32,16 +32,16 @@ from tensorflow.python.framework import ops from tensorflow.python.platform import test -def whitelist(f): +def allowlist(f): - """Helper that marks a callable as whtelitisted.""" + """Helper that marks a callable as allowlisted.""" - if 'whitelisted_module_for_testing' not in sys.modules: - whitelisted_mod = imp.new_module('whitelisted_module_for_testing') - sys.modules['whitelisted_module_for_testing'] = whitelisted_mod + if 'allowlisted_module_for_testing' not in sys.modules: + allowlisted_mod = imp.new_module('allowlisted_module_for_testing') + sys.modules['allowlisted_module_for_testing'] = allowlisted_mod config.CONVERSION_RULES = ( - (config.DoNotConvert('whitelisted_module_for_testing'),) + + (config.DoNotConvert('allowlisted_module_for_testing'),) + config.CONVERSION_RULES) - f.__module__ = 'whitelisted_module_for_testing' + f.__module__ = 'allowlisted_module_for_testing' def is_inside_generated_code(): diff --git a/tensorflow/python/autograph/g3doc/reference/functions.md b/tensorflow/python/autograph/g3doc/reference/functions.md index 83c4fbe9bea..48bf052f298 100644 --- a/tensorflow/python/autograph/g3doc/reference/functions.md +++ b/tensorflow/python/autograph/g3doc/reference/functions.md @@ -44,18 +44,18 @@ are handled correctly.
The following types of functions are not converted: - * functions already converted - * functions defined in in a whitelisted module (see autograph/core/config.py) - * non-Python functions (such as native bindings) - * `print`, `pdb.set_trace`, `ipdb.set_trace` - * most built-in functions (exceptions are listed in +* functions already converted +* functions defined in an allowlisted module (see autograph/core/config.py) +* non-Python functions (such as native bindings) +* `print`, `pdb.set_trace`, `ipdb.set_trace` +* most built-in functions (exceptions are listed in autograph/operators/py_builtins.py) - * constructors - * functions without source code attached (prints a warning)(see +* constructors +* functions without source code attached (prints a warning)(see [limitations](limitations.md)) - * generator functions (prints a warning) - * iterator protocol methods (`__next__`, `__iter__`) - * context manager methods (`__enter__`, `__exit__`) +* generator functions (prints a warning) +* iterator protocol methods (`__next__`, `__iter__`) +* context manager methods (`__enter__`, `__exit__`) When AutoGraph encounters a function that it cannot convert outside of this list, it prints a warning. diff --git a/tensorflow/python/autograph/impl/api.py b/tensorflow/python/autograph/impl/api.py index a5e1ab1705f..8c7093c864d 100644 --- a/tensorflow/python/autograph/impl/api.py +++ b/tensorflow/python/autograph/impl/api.py @@ -342,16 +342,16 @@ def converted_call(f, raise ValueError('either caller_fn_scope or options must have a value') options = caller_fn_scope.callopts - if conversion.is_in_whitelist_cache(f, options): - logging.log(2, 'Whitelisted %s: from cache', f) + if conversion.is_in_allowlist_cache(f, options): + logging.log(2, 'Allowlisted %s: from cache', f) return _call_unconverted(f, args, kwargs, options, False) if ag_ctx.control_status_ctx().status == ag_ctx.Status.DISABLED: - logging.log(2, 'Whitelisted: %s: AutoGraph is disabled in context', f) + logging.log(2, 'Allowlisted: %s: AutoGraph is disabled in context', f) return _call_unconverted(f, args, kwargs, options, False) if is_autograph_artifact(f): - logging.log(2, 'Permanently whitelisted: %s: AutoGraph artifact', f) + logging.log(2, 'Permanently allowed: %s: AutoGraph artifact', f) return _call_unconverted(f, args, kwargs, options) # If this is a partial, unwrap it and redo all the checks. @@ -385,7 +385,7 @@ def converted_call(f, if conversion.is_unsupported(f): return _call_unconverted(f, args, kwargs, options) - if not options.user_requested and conversion.is_whitelisted(f): + if not options.user_requested and conversion.is_allowlisted(f): return _call_unconverted(f, args, kwargs, options) # internal_convert_user_code is for example turned off when issuing a dynamic @@ -425,13 +425,13 @@ def converted_call(f, return _fall_back_unconverted(f, args, kwargs, options, e) if not hasattr(target_entity, '__code__'): - logging.log(2, 'Permanently whitelisted: %s: native binding', + logging.log(2, 'Permanently allowed: %s: native binding', target_entity) return _call_unconverted(f, args, kwargs, options) elif (hasattr(target_entity.__code__, 'co_filename') and target_entity.__code__.co_filename == '<string>'): # TODO(mdan): __globals__['txt'] might work in Py3.
- logging.log(2, 'Permanently whitelisted: %s: dynamic code (exec?)', + logging.log(2, 'Permanently allowed: %s: dynamic code (exec?)', target_entity) return _call_unconverted(f, args, kwargs, options) @@ -462,7 +462,7 @@ def converted_call(f, def _call_unconverted(f, args, kwargs, options, update_cache=True): """Calls the original function without converting with AutoGraph.""" if update_cache: - conversion.cache_whitelisted(f, options) + conversion.cache_allowlisted(f, options) if inspect.ismethod(f) and isinstance(f.__self__, function.TfMethodTarget): return f.__self__.call(args, kwargs) @@ -482,7 +482,7 @@ def _fall_back_unconverted(f, args, kwargs, options, exc): 'To silence this warning, decorate the function with' ' @tf.autograph.experimental.do_not_convert') if isinstance(exc, errors.UnsupportedLanguageElementError): - if not conversion.is_in_whitelist_cache(f, options): + if not conversion.is_in_allowlist_cache(f, options): logging.warn(warning_template, f, '', exc) else: file_bug_message = ( @@ -516,7 +516,7 @@ def tf_convert(f, ctx, convert_by_default=True, user_requested=False): ctx: ag_ctx.ControlStatusCtx, the Autograph context in which `f` is used. convert_by_default: bool, whether to use AutoGraph when the context doesn't specify. - user_requested: bool, whether to ignore the conversion whitelist. See + user_requested: bool, whether to ignore the conversion allowlist. See ConversionOptions.user_requested. Returns: diff --git a/tensorflow/python/autograph/impl/api_test.py b/tensorflow/python/autograph/impl/api_test.py index 118258b3b91..5b885af43ac 100644 --- a/tensorflow/python/autograph/impl/api_test.py +++ b/tensorflow/python/autograph/impl/api_test.py @@ -203,14 +203,14 @@ class ApiTest(test.TestCase): z = x + y return z - test_method_whitelisted = api.do_not_convert(test_method) + test_method_allowlisted = api.do_not_convert(test_method) tc = TestClass() - self.assertTrue(tf_inspect.ismethod(tc.test_method_whitelisted)) + self.assertTrue(tf_inspect.ismethod(tc.test_method_allowlisted)) # Because the wrapped function is not generated, we can't preserve its # arg spec. 
self.assertEqual((), - tuple(function_utils.fn_args(tc.test_method_whitelisted))) + tuple(function_utils.fn_args(tc.test_method_allowlisted))) def test_do_not_convert_callable_object(self): @@ -521,12 +521,12 @@ class ApiTest(test.TestCase): ag_logging.set_verbosity(0, False) os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' - def test_converted_call_partial_of_whitelisted_function(self): + def test_converted_call_partial_of_allowlisted_function(self): def test_fn(_): self.assertFalse(converter_testing.is_inside_generated_code()) - converter_testing.whitelist(test_fn) + converter_testing.allowlist(test_fn) api.converted_call( functools.partial(test_fn, None), (), None, options=DEFAULT_RECURSIVE) @@ -563,7 +563,7 @@ class ApiTest(test.TestCase): f, (g, constant_op.constant(1)), None, options=DEFAULT_RECURSIVE) self.assertEqual(self.evaluate(x), 1) - def test_converted_call_forced_when_explicitly_whitelisted(self): + def test_converted_call_forced_when_explicitly_allowlisted(self): @api.do_not_convert() def f(x): @@ -606,7 +606,7 @@ class ApiTest(test.TestCase): self.assertIsNotNone( api.converted_call(f, (1, 2, 3, 4), None, options=opts)) - def test_converted_call_whitelisted_method(self): + def test_converted_call_allowlisted_method(self): class TestClass(object): @@ -614,19 +614,19 @@ class ApiTest(test.TestCase): return converter_testing.is_inside_generated_code() obj = TestClass() - converter_testing.whitelist(obj.method.__func__) + converter_testing.allowlist(obj.method.__func__) self.assertFalse( api.converted_call(obj.method, (), {}, options=DEFAULT_RECURSIVE)) - def test_converted_call_whitelisted_method_via_owner(self): + def test_converted_call_allowlisted_method_via_owner(self): class TestClass(object): def method(self): return converter_testing.is_inside_generated_code() - converter_testing.whitelist(TestClass) + converter_testing.allowlist(TestClass) obj = TestClass() self.assertFalse( @@ -852,7 +852,7 @@ class ApiTest(test.TestCase): # invocation would fail. self.assertEqual(self.evaluate(call_in_default_context()), 1) - def test_converted_call_caching_of_whitelisted_bound_methods(self): + def test_converted_call_caching_of_allowlisted_bound_methods(self): class TestClass(object): @@ -863,7 +863,7 @@ class ApiTest(test.TestCase): return self.__private # TODO(mdan): Refactor to avoid this use of global state. - cache_size_before = len(conversion._WHITELIST_CACHE) + cache_size_before = len(conversion._ALLOWLIST_CACHE) # First invocation with fallback on, to allow recording it into cache. os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '0' @@ -871,15 +871,15 @@ class ApiTest(test.TestCase): api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE) os.environ['AUTOGRAPH_STRICT_CONVERSION'] = '1' - # Entry should be added to the whitelist cache. - self.assertEqual(len(conversion._WHITELIST_CACHE), cache_size_before + 1) + # Entry should be added to the allowlist cache. + self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1) # A second invocation should go through even with fallback off. tc = TestClass() api.converted_call(tc.test_method, (), None, options=DEFAULT_RECURSIVE) - # No new entries should appear in the whitelist cache. - self.assertEqual(len(conversion._WHITELIST_CACHE), cache_size_before + 1) + # No new entries should appear in the allowlist cache. 
+ self.assertEqual(len(conversion._ALLOWLIST_CACHE), cache_size_before + 1) def test_context_tracking_direct_calls(self): @@ -1102,7 +1102,7 @@ class ApiTest(test.TestCase): test_fn(ag_ctx.ControlStatusCtx(status=ag_ctx.Status.UNSPECIFIED)) - def test_tf_convert_whitelisted_method(self): + def test_tf_convert_allowlisted_method(self): if six.PY2: self.skipTest('Test bank not comptible with Python 2.') @@ -1112,7 +1112,7 @@ class ApiTest(test.TestCase): def method(self): return converter_testing.is_inside_generated_code() - converter_testing.whitelist(TestClass.method) + converter_testing.allowlist(TestClass.method) obj = TestClass() converted_call = api.tf_convert( diff --git a/tensorflow/python/autograph/impl/conversion.py b/tensorflow/python/autograph/impl/conversion.py index 3c1f7e97bde..d73b35283f1 100644 --- a/tensorflow/python/autograph/impl/conversion.py +++ b/tensorflow/python/autograph/impl/conversion.py @@ -31,7 +31,7 @@ from tensorflow.python.eager import function from tensorflow.python.util import tf_inspect -_WHITELIST_CACHE = cache.UnboundInstanceCache() +_ALLOWLIST_CACHE = cache.UnboundInstanceCache() def _is_of_known_loaded_module(f, module_name): @@ -80,53 +80,53 @@ def is_unsupported(o): '{} appears to be decorated by wrapt, which is not yet supported' ' by AutoGraph. The function will run as-is.' ' You may still apply AutoGraph before the wrapt decorator.'.format(o)) - logging.log(2, 'Permanently whitelisted: %s: wrapt decorated', o) + logging.log(2, 'Permanently allowed: %s: wrapt decorated', o) return True if _is_known_loaded_type(o, 'functools', '_lru_cache_wrapper'): - logging.log(2, 'Permanently whitelisted: %s: lru_cache', o) + logging.log(2, 'Permanently allowed: %s: lru_cache', o) return True - # Constructors are permanently whitelisted. + # Constructors are permanently allowed. # TODO(mdan): Toggle as experimental feature instead. # TODO(b/124016764): Remove this limitation. if inspect_utils.isconstructor(o): - logging.log(2, 'Permanently whitelisted: %s: constructor', o) + logging.log(2, 'Permanently allowed: %s: constructor', o) return True - # Other built-in modules are permanently whitelisted. + # Other built-in modules are permanently allowed. # TODO(mdan): Figure out how to do this consistently for all stdlib modules. if any( _is_of_known_loaded_module(o, m) for m in ('collections', 'pdb', 'copy', 'inspect', 're')): - logging.log(2, 'Permanently whitelisted: %s: part of builtin module', o) + logging.log(2, 'Permanently allowed: %s: part of builtin module', o) return True - # Custom ops and kernels are also permanently whitelisted. + # Custom ops and kernels are also permanently allowed. # See tensorflow.framework.load_library. if (hasattr(o, '__module__') and hasattr(o.__module__, '_IS_TENSORFLOW_PLUGIN')): - logging.log(2, 'Permanently whitelisted: %s: TensorFlow plugin', o) + logging.log(2, 'Permanently allowed: %s: TensorFlow plugin', o) return True return False # TODO(mdan): allow_namedtuple_subclass should be hardcoded to True. -def is_whitelisted( +def is_allowlisted( o, check_call_override=True, allow_namedtuple_subclass=False): - """Checks whether an entity is whitelisted for use in graph mode. + """Checks whether an entity is allowed for use in graph mode. - Examples of whitelisted entities include all members of the tensorflow + Examples of allowed entities include all members of the tensorflow package. Args: o: A Python entity. check_call_override: Reserved for internal use. 
When set to `False`, it - disables the rule according to which classes are whitelisted if their - __call__ method is whitelisted. + disables the rule according to which classes are allowed if their + __call__ method is allowed. allow_namedtuple_subclass: Reserved for internal use. When `True`, - namedtuple subclasses are not whitelisted. + namedtuple subclasses are not allowed. Returns: Boolean @@ -144,10 +144,10 @@ def is_whitelisted( for rule in config.CONVERSION_RULES: action = rule.get_action(m) if action == config.Action.CONVERT: - logging.log(2, 'Not whitelisted: %s: %s', o, rule) + logging.log(2, 'Not allowed: %s: %s', o, rule) return False elif action == config.Action.DO_NOT_CONVERT: - logging.log(2, 'Whitelisted: %s: %s', o, rule) + logging.log(2, 'Allowlisted: %s: %s', o, rule) return True # The check for __code__ below is because isgeneratorfunction crashes @@ -156,26 +156,26 @@ def is_whitelisted( logging.warn( 'Entity %s appears to be a generator function. It will not be converted' ' by AutoGraph.', o) - logging.log(2, 'Whitelisted: %s: generator functions are not converted', o) + logging.log(2, 'Allowlisted: %s: generator functions are not converted', o) return True if (check_call_override and not tf_inspect.isclass(o) and hasattr(o, '__call__')): - # Callable objects: whitelisted if their __call__ method is. + # Callable objects: allowed if their __call__ method is. # The type check avoids infinite recursion around the __call__ method # of function objects. - if (type(o) != type(o.__call__)) and is_whitelisted(o.__call__): # pylint: disable=unidiomatic-typecheck - logging.log(2, 'Whitelisted: %s: object __call__ whitelisted', o) + if (type(o) != type(o.__call__)) and is_allowlisted(o.__call__): # pylint: disable=unidiomatic-typecheck + logging.log(2, 'Allowlisted: %s: object __call__ allowed', o) return True owner_class = None if tf_inspect.ismethod(o): - # Methods of whitelisted classes are also whitelisted, even if they are + # Methods of allowed classes are also allowed, even if they are # bound via user subclasses. # # For example, suppose `tf.Foo` has a method called `bar`, and `baz` is - # defined as below. `tf.Foo` is whitelisted. Then `baz.bar` is also - # whitelisted. + # defined as below. `tf.Foo` is allowed. Then `baz.bar` is also + # allowed. # # class Custom(tf.Foo): # pass @@ -183,22 +183,22 @@ def is_whitelisted( # baz = Custom() # # For the example above, if `Custom` did overload `bar`, then it would no - # longer be whitelisted. + # longer be allowed. owner_class = inspect_utils.getmethodclass(o) if owner_class is function.TfMethodTarget: owner_class = o.__self__.target_class if owner_class is not None: if issubclass(owner_class, unittest.TestCase): - logging.log(2, 'Whitelisted: %s: method of TestCase subclass', o) + logging.log(2, 'Allowlisted: %s: method of TestCase subclass', o) return True owner_class = inspect_utils.getdefiningclass(o, owner_class) - if is_whitelisted( + if is_allowlisted( owner_class, check_call_override=False, allow_namedtuple_subclass=True): - logging.log(2, 'Whitelisted: %s: owner is whitelisted %s', o, + logging.log(2, 'Allowlisted: %s: owner is allowed %s', o, owner_class) return True @@ -208,27 +208,27 @@ def is_whitelisted( # graph mode since they are just containers. 
if allow_namedtuple_subclass: if not any(inspect_utils.isnamedtuple(base) for base in o.__bases__): - logging.log(2, 'Whitelisted: %s: named tuple', o) + logging.log(2, 'Allowlisted: %s: named tuple', o) return True else: - logging.log(2, 'Whitelisted: %s: named tuple or subclass', o) + logging.log(2, 'Allowlisted: %s: named tuple or subclass', o) return True - logging.log(2, 'Not whitelisted: %s: default rule', o) + logging.log(2, 'Not allowed: %s: default rule', o) return False -def is_in_whitelist_cache(entity, options): +def is_in_allowlist_cache(entity, options): try: - return _WHITELIST_CACHE.has(entity, options) + return _ALLOWLIST_CACHE.has(entity, options) except TypeError: # Catch-all for entities that are unhashable or don't allow weakrefs. return False -def cache_whitelisted(entity, options): +def cache_allowlisted(entity, options): try: - _WHITELIST_CACHE[entity][options] = True + _ALLOWLIST_CACHE[entity][options] = True except TypeError: # Catch-all for entities that are unhashable or don't allow weakrefs. pass diff --git a/tensorflow/python/autograph/impl/conversion_test.py b/tensorflow/python/autograph/impl/conversion_test.py index 24d93e18b24..d2b2f111729 100644 --- a/tensorflow/python/autograph/impl/conversion_test.py +++ b/tensorflow/python/autograph/impl/conversion_test.py @@ -43,16 +43,16 @@ class ConversionTest(test.TestCase): options=converter.ConversionOptions(recursive=True), autograph_module=api) - def test_is_whitelisted(self): + def test_is_allowlisted(self): def test_fn(): return constant_op.constant(1) - self.assertFalse(conversion.is_whitelisted(test_fn)) - self.assertTrue(conversion.is_whitelisted(utils)) - self.assertTrue(conversion.is_whitelisted(constant_op.constant)) + self.assertFalse(conversion.is_allowlisted(test_fn)) + self.assertTrue(conversion.is_allowlisted(utils)) + self.assertTrue(conversion.is_allowlisted(constant_op.constant)) - def test_is_whitelisted_tensorflow_like(self): + def test_is_allowlisted_tensorflow_like(self): tf_like = imp.new_module('tensorflow_foo') def test_fn(): @@ -60,13 +60,13 @@ class ConversionTest(test.TestCase): tf_like.test_fn = test_fn test_fn.__module__ = tf_like - self.assertFalse(conversion.is_whitelisted(tf_like.test_fn)) + self.assertFalse(conversion.is_allowlisted(tf_like.test_fn)) - def test_is_whitelisted_callable_whitelisted_call(self): + def test_is_allowlisted_callable_allowlisted_call(self): - whitelisted_mod = imp.new_module('test_whitelisted_call') - sys.modules['test_whitelisted_call'] = whitelisted_mod - config.CONVERSION_RULES = ((config.DoNotConvert('test_whitelisted_call'),) + + allowlisted_mod = imp.new_module('test_allowlisted_call') + sys.modules['test_allowlisted_call'] = allowlisted_mod + config.CONVERSION_RULES = ((config.DoNotConvert('test_allowlisted_call'),) + config.CONVERSION_RULES) class TestClass(object): @@ -74,14 +74,14 @@ class ConversionTest(test.TestCase): def __call__(self): pass - def whitelisted_method(self): + def allowlisted_method(self): pass - TestClass.__module__ = 'test_whitelisted_call' + TestClass.__module__ = 'test_allowlisted_call' if six.PY2: - TestClass.__call__.__func__.__module__ = 'test_whitelisted_call' + TestClass.__call__.__func__.__module__ = 'test_allowlisted_call' else: - TestClass.__call__.__module__ = 'test_whitelisted_call' + TestClass.__call__.__module__ = 'test_allowlisted_call' class Subclass(TestClass): @@ -90,20 +90,21 @@ class ConversionTest(test.TestCase): tc = Subclass() - self.assertTrue(conversion.is_whitelisted(TestClass.__call__)) - 
self.assertTrue(conversion.is_whitelisted(tc)) - self.assertTrue(conversion.is_whitelisted(tc.__call__)) - self.assertTrue(conversion.is_whitelisted(tc.whitelisted_method)) - self.assertFalse(conversion.is_whitelisted(Subclass)) - self.assertFalse(conversion.is_whitelisted(tc.converted_method)) + self.assertTrue(conversion.is_allowlisted(TestClass.__call__)) + self.assertTrue(conversion.is_allowlisted(tc)) + self.assertTrue(conversion.is_allowlisted(tc.__call__)) + self.assertTrue(conversion.is_allowlisted(tc.allowlisted_method)) + self.assertFalse(conversion.is_allowlisted(Subclass)) + self.assertFalse(conversion.is_allowlisted(tc.converted_method)) + + def test_is_allowlisted_tfmethodwrapper(self): - def test_is_whitelisted_tfmethodwrapper(self): class TestClass(object): def member_function(self): pass - TestClass.__module__ = 'test_whitelisted_call' + TestClass.__module__ = 'test_allowlisted_call' test_obj = TestClass() def test_fn(self): @@ -114,14 +115,14 @@ class ConversionTest(test.TestCase): function.TfMethodTarget( weakref.ref(test_obj), test_obj.member_function)) - self.assertTrue(conversion.is_whitelisted(bound_method)) + self.assertTrue(conversion.is_allowlisted(bound_method)) - def test_is_whitelisted_pybind(self): + def test_is_allowlisted_pybind(self): test_object = pybind_for_testing.TestClassDef() with test.mock.patch.object(config, 'CONVERSION_RULES', ()): # TODO(mdan): This should return True for functions and methods. - # Note: currently, native bindings are whitelisted by a separate check. - self.assertFalse(conversion.is_whitelisted(test_object.method)) + # Note: currently, native bindings are allowlisted by a separate check. + self.assertFalse(conversion.is_allowlisted(test_object.method)) if __name__ == '__main__': diff --git a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py index d8d13fefb0f..58d2421cb84 100644 --- a/tensorflow/python/autograph/pyct/common_transformers/anf_test.py +++ b/tensorflow/python/autograph/pyct/common_transformers/anf_test.py @@ -477,12 +477,14 @@ class AnfConfiguredTest(AnfTestBase): def test_anf_some_function_calls(self): # Another example specific configuration that differs from the default: # Moving all arguments out of some function calls but leaving others be. - whitelist = ['foo'] + allowlist = ['foo'] + def transform(parent, field, child): del field del child func_name = parent.func.id - return str(func_name) in whitelist + return str(func_name) in allowlist + config = [(anf.ASTEdgePattern(gast.Call, anf.ANY, anf.ANY), transform)] def test_function(x, foo, bar): diff --git a/tensorflow/python/autograph/pyct/error_utils.py b/tensorflow/python/autograph/pyct/error_utils.py index 3e9b8754c3c..5ab45d8a0fd 100644 --- a/tensorflow/python/autograph/pyct/error_utils.py +++ b/tensorflow/python/autograph/pyct/error_utils.py @@ -24,10 +24,9 @@ from tensorflow.python.autograph.pyct import origin_info class FrameInfo( - collections.namedtuple( - 'FrameInfo', - ('filename', 'lineno', 'function_name', 'code', 'is_converted', - 'is_whitelisted'))): + collections.namedtuple('FrameInfo', + ('filename', 'lineno', 'function_name', 'code', + 'is_converted', 'is_allowlisted'))): __slots__ = () @@ -75,7 +74,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename): origin_info.create_source_map. converter_filename: str, the file path of the converted module. 
Call frames corresponding to this module are elided and their preceding frames are - marked as whitelisted. Note that frames enclosing converted code are + marked as allowlisted. Note that frames enclosing converted code are dropped using a different mechanism. Returns: @@ -93,7 +92,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename): function_name=origin.function_name, code=origin.source_code_line, is_converted=True, - is_whitelisted=False) + is_allowlisted=False) result_frames.append(fi) break @@ -107,7 +106,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename): function_name=prev.function_name, code=prev.code, is_converted=False, - is_whitelisted=True) + is_allowlisted=True) result_frames[-1] = fi continue @@ -117,7 +116,7 @@ def _stack_trace_inside_mapped_code(tb, source_map, converter_filename): function_name=function_name, code=text, is_converted=False, - is_whitelisted=False) + is_allowlisted=False) result_frames.append(fi) return tuple(result_frames) @@ -188,7 +187,7 @@ class ErrorMetadataBase(object): frame_info.function_name) if frame_info.is_converted: formatted_line += ' *' - elif frame_info.is_whitelisted: + elif frame_info.is_allowlisted: formatted_line += ' **' lines.append(formatted_line) diff --git a/tensorflow/python/data/ops/dataset_ops.py b/tensorflow/python/data/ops/dataset_ops.py index 586b82e9ca6..e235751617f 100644 --- a/tensorflow/python/data/ops/dataset_ops.py +++ b/tensorflow/python/data/ops/dataset_ops.py @@ -2250,7 +2250,7 @@ class DatasetV1(DatasetV2): # by value _make_dataset() function would try to capture these variant # tensor dataset inputs, which are marked as stateful ops and would throw # an error if we try and capture them. We therefore traverse the graph - # to find all these ops and whitelist them so that the capturing + # to find all these ops and allowlist them so that the capturing # logic instead of throwing an error recreates these ops which is what was # happening before. all_ds_ops = traverse.obtain_all_variant_tensor_ops(self) @@ -2258,7 +2258,7 @@ class DatasetV1(DatasetV2): # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is # a 0-argument function. 
- @function.Defun(capture_by_value=True, whitelisted_stateful_ops=all_ds_ops) + @function.Defun(capture_by_value=True, allowlisted_stateful_ops=all_ds_ops) def _make_dataset(): """Factory function for a dataset.""" # NOTE(mrry): `Defun` does not capture the graph-level seed from the diff --git a/tensorflow/python/debug/cli/analyzer_cli.py b/tensorflow/python/debug/cli/analyzer_cli.py index 0be1f5894c2..49b48fd2dcc 100644 --- a/tensorflow/python/debug/cli/analyzer_cli.py +++ b/tensorflow/python/debug/cli/analyzer_cli.py @@ -1246,8 +1246,8 @@ class DebugAnalyzer(object): parsed = self._arg_parsers["list_source"].parse_args(args) source_list = source_utils.list_source_files_against_dump( self._debug_dump, - path_regex_whitelist=parsed.path_filter, - node_name_regex_whitelist=parsed.node_name_filter) + path_regex_allowlist=parsed.path_filter, + node_name_regex_allowlist=parsed.node_name_filter) top_lines = [ RL("List of source files that created nodes in this run", "bold")] diff --git a/tensorflow/python/debug/cli/analyzer_cli_test.py b/tensorflow/python/debug/cli/analyzer_cli_test.py index cba446e8157..7c662faa59c 100644 --- a/tensorflow/python/debug/cli/analyzer_cli_test.py +++ b/tensorflow/python/debug/cli/analyzer_cli_test.py @@ -1578,9 +1578,9 @@ class AnalyzerCLISimpleMulAddTest(test_util.TensorFlowTestCase): def testListSourceWithCompiledPythonSourceWorks(self): def fake_list_source_files_against_dump(dump, - path_regex_whitelist=None, - node_name_regex_whitelist=None): - del dump, path_regex_whitelist, node_name_regex_whitelist + path_regex_allowlist=None, + node_name_regex_allowlist=None): + del dump, path_regex_allowlist, node_name_regex_allowlist return [("compiled_1.pyc", False, 10, 20, 30, 4), ("compiled_2.pyo", False, 10, 20, 30, 5), ("uncompiled.py", False, 10, 20, 30, 6)] diff --git a/tensorflow/python/debug/lib/check_numerics_callback.py b/tensorflow/python/debug/lib/check_numerics_callback.py index bd88ec5e122..da937a09d0b 100644 --- a/tensorflow/python/debug/lib/check_numerics_callback.py +++ b/tensorflow/python/debug/lib/check_numerics_callback.py @@ -38,7 +38,7 @@ from tensorflow.python.util.tf_export import tf_export # Many ops have benign NaN outputs, and running them with check_numerics # on will create unwanted errors -# TODO(b/142497024): Replace this whitelist with function decorators in the ops +# TODO(b/142497024): Replace this allowlist with function decorators in the ops IGNORE_OP_OUTPUTS = ( # For FusedBatchNorm, if the input tensor is empty then batch_mean and # batch_variance will be NaN. reserve_space holds intermediate values diff --git a/tensorflow/python/debug/lib/debug_utils.py b/tensorflow/python/debug/lib/debug_utils.py index eb21694ba2f..61575cdef76 100644 --- a/tensorflow/python/debug/lib/debug_utils.py +++ b/tensorflow/python/debug/lib/debug_utils.py @@ -83,16 +83,16 @@ def watch_graph(run_options, graph, debug_ops="DebugIdentity", debug_urls=None, - node_name_regex_whitelist=None, - op_type_regex_whitelist=None, - tensor_dtype_regex_whitelist=None, + node_name_regex_allowlist=None, + op_type_regex_allowlist=None, + tensor_dtype_regex_allowlist=None, tolerate_debug_op_creation_failures=False, global_step=-1, reset_disk_byte_usage=False): """Add debug watches to `RunOptions` for a TensorFlow graph. - To watch all `Tensor`s on the graph, let both `node_name_regex_whitelist` - and `op_type_regex_whitelist` be the default (`None`). 
+ To watch all `Tensor`s on the graph, let both `node_name_regex_allowlist` + and `op_type_regex_allowlist` be the default (`None`). N.B.: 1. Under certain circumstances, the `Tensor` may not get actually watched @@ -114,17 +114,17 @@ def watch_graph(run_options, For debug op types with customizable attributes, each debug op name string can optionally contain a list of attribute names, in the syntax of: debug_op_name(attr_name_1=attr_value_1;attr_name_2=attr_value_2;...) - node_name_regex_whitelist: Regular-expression whitelist for node_name, + node_name_regex_allowlist: Regular-expression allowlist for node_name, e.g., `"(weight_[0-9]+|bias_.*)"` - op_type_regex_whitelist: Regular-expression whitelist for the op type of + op_type_regex_allowlist: Regular-expression allowlist for the op type of nodes, e.g., `"(Variable|Add)"`. - If both `node_name_regex_whitelist` and `op_type_regex_whitelist` + If both `node_name_regex_allowlist` and `op_type_regex_allowlist` are set, the two filtering operations will occur in a logical `AND` relation. In other words, a node will be included if and only if it - hits both whitelists. - tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor + hits both allowlists. + tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor data type, e.g., `"^int.*"`. - This whitelist operates in logical `AND` relations to the two whitelists + This allowlist operates in logical `AND` relations to the two allowlists above. tolerate_debug_op_creation_failures: (`bool`) whether debug op creation failures (e.g., due to dtype incompatibility) are to be tolerated by not @@ -142,12 +142,14 @@ def watch_graph(run_options, if isinstance(debug_ops, str): debug_ops = [debug_ops] - node_name_pattern = (re.compile(node_name_regex_whitelist) - if node_name_regex_whitelist else None) - op_type_pattern = (re.compile(op_type_regex_whitelist) - if op_type_regex_whitelist else None) - tensor_dtype_pattern = (re.compile(tensor_dtype_regex_whitelist) - if tensor_dtype_regex_whitelist else None) + node_name_pattern = ( + re.compile(node_name_regex_allowlist) + if node_name_regex_allowlist else None) + op_type_pattern = ( + re.compile(op_type_regex_allowlist) if op_type_regex_allowlist else None) + tensor_dtype_pattern = ( + re.compile(tensor_dtype_regex_allowlist) + if tensor_dtype_regex_allowlist else None) ops = graph.get_operations() for op in ops: @@ -210,7 +212,7 @@ def watch_graph_with_blacklists(run_options, """Add debug tensor watches, blacklisting nodes and op types. This is similar to `watch_graph()`, but the node names and op types are - blacklisted, instead of whitelisted. + blacklisted, instead of allowlisted. N.B.: 1. Under certain circumstances, the `Tensor` may not get actually watched @@ -238,7 +240,7 @@ def watch_graph_with_blacklists(run_options, neither of the blacklists. tensor_dtype_regex_blacklist: Regular-expression blacklist for Tensor data type, e.g., `"^int.*"`. - This blacklist operates in logical `OR` relations to the two whitelists + This blacklist operates in logical `OR` relations to the two allowlists above. 
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation failures (e.g., due to dtype incompatibility) are to be tolerated by not diff --git a/tensorflow/python/debug/lib/debug_utils_test.py b/tensorflow/python/debug/lib/debug_utils_test.py index c8effc8eeed..188b89debec 100644 --- a/tensorflow/python/debug/lib/debug_utils_test.py +++ b/tensorflow/python/debug/lib/debug_utils_test.py @@ -227,12 +227,12 @@ class DebugUtilsTest(test_util.TensorFlowTestCase): # Assert that the wildcard node name has been created. self.assertIn("*", node_names) - def testWatchGraph_nodeNameWhitelist(self): + def testWatchGraph_nodeNameAllowlist(self): debug_utils.watch_graph( self._run_options, self._graph, debug_urls="file:///tmp/tfdbg_1", - node_name_regex_whitelist="(a1$|a1_init$|a1/.*|p1$)") + node_name_regex_allowlist="(a1$|a1_init$|a1/.*|p1$)") node_names = self._verify_watches( self._run_options.debug_options.debug_tensor_watch_opts, 0, @@ -241,50 +241,50 @@ class DebugUtilsTest(test_util.TensorFlowTestCase): sorted(["a1_init", "a1", "a1/Assign", "a1/read", "p1"]), sorted(node_names)) - def testWatchGraph_opTypeWhitelist(self): + def testWatchGraph_opTypeAllowlist(self): debug_utils.watch_graph( self._run_options, self._graph, debug_urls="file:///tmp/tfdbg_1", - op_type_regex_whitelist="(Variable|MatMul)") + op_type_regex_allowlist="(Variable|MatMul)") node_names = self._verify_watches( self._run_options.debug_options.debug_tensor_watch_opts, 0, ["DebugIdentity"], ["file:///tmp/tfdbg_1"]) self.assertEqual(sorted(["a1", "b", "p1"]), sorted(node_names)) - def testWatchGraph_nodeNameAndOpTypeWhitelists(self): + def testWatchGraph_nodeNameAndOpTypeAllowlists(self): debug_utils.watch_graph( self._run_options, self._graph, debug_urls="file:///tmp/tfdbg_1", - node_name_regex_whitelist="([a-z]+1$)", - op_type_regex_whitelist="(MatMul)") + node_name_regex_allowlist="([a-z]+1$)", + op_type_regex_allowlist="(MatMul)") node_names = self._verify_watches( self._run_options.debug_options.debug_tensor_watch_opts, 0, ["DebugIdentity"], ["file:///tmp/tfdbg_1"]) self.assertEqual(["p1"], node_names) - def testWatchGraph_tensorDTypeWhitelist(self): + def testWatchGraph_tensorDTypeAllowlist(self): debug_utils.watch_graph( self._run_options, self._graph, debug_urls="file:///tmp/tfdbg_1", - tensor_dtype_regex_whitelist=".*_ref") + tensor_dtype_regex_allowlist=".*_ref") node_names = self._verify_watches( self._run_options.debug_options.debug_tensor_watch_opts, 0, ["DebugIdentity"], ["file:///tmp/tfdbg_1"]) self.assertItemsEqual(["a1", "a1/Assign", "b", "b/Assign"], node_names) - def testWatchGraph_nodeNameAndTensorDTypeWhitelists(self): + def testWatchGraph_nodeNameAndTensorDTypeAllowlists(self): debug_utils.watch_graph( self._run_options, self._graph, debug_urls="file:///tmp/tfdbg_1", - node_name_regex_whitelist="^a.*", - tensor_dtype_regex_whitelist=".*_ref") + node_name_regex_allowlist="^a.*", + tensor_dtype_regex_allowlist=".*_ref") node_names = self._verify_watches( self._run_options.debug_options.debug_tensor_watch_opts, 0, diff --git a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py index 0dc01748b08..4b1d1930e6d 100644 --- a/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/dist_session_debug_grpc_test.py @@ -143,7 +143,7 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase): debug_utils.watch_graph( run_options, sess.graph, - node_name_regex_whitelist=r"a", + 
node_name_regex_allowlist=r"a", debug_ops=["DebugIdentity"], debug_urls=[self.debug_server_url]) @@ -155,7 +155,7 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase): debug_utils.watch_graph( run_options, sess.graph, - node_name_regex_whitelist=r"p", + node_name_regex_allowlist=r"p", debug_ops=["DebugIdentity(gated_grpc=True)"], debug_urls=[self.debug_server_url]) @@ -209,8 +209,8 @@ class DistributedSessionDebugTest(test_util.TensorFlowTestCase): def watch_fn(feeds, fetch_keys): del feeds, fetch_keys return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"p") + debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"p") + sess = grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) diff --git a/tensorflow/python/debug/lib/grpc_large_data_test.py b/tensorflow/python/debug/lib/grpc_large_data_test.py index 64d5a241531..7748c3992bf 100644 --- a/tensorflow/python/debug/lib/grpc_large_data_test.py +++ b/tensorflow/python/debug/lib/grpc_large_data_test.py @@ -71,7 +71,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): del fetches, feeds return framework.WatchOptions( debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"original_u") + node_name_regex_allowlist=r"original_u") + sess = grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) self.assertAllClose(42.0, sess.run(u)) @@ -101,8 +102,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): def watch_fn(fetches, feeds): del fetches, feeds # Unused by this watch_fn. return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") + debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) sess.run(u.initializer) @@ -125,8 +126,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): def watch_fn(fetches, feeds): del fetches, feeds return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") + debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) sess.run(u.initializer) @@ -155,8 +156,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): def watch_fn(fetches, feeds): del fetches, feeds return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") + debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) sess.run(u.initializer) @@ -177,8 +178,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): def watch_fn(fetches, feeds): del fetches, feeds return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") + debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) sess.run(u.initializer) @@ -200,8 +201,8 @@ class LargeGraphAndLargeTensorsDebugTest(test_util.TensorFlowTestCase): def watch_fn(fetches, feeds): del fetches, feeds return framework.WatchOptions( - debug_ops=["DebugIdentity"], - node_name_regex_whitelist=r"u_init") + 
debug_ops=["DebugIdentity"], node_name_regex_allowlist=r"u_init") + sess = grpc_wrapper.GrpcDebugWrapperSession( sess, "localhost:%d" % self.debug_server_port, watch_fn=watch_fn) sess.run(u.initializer) diff --git a/tensorflow/python/debug/lib/session_debug_grpc_test.py b/tensorflow/python/debug/lib/session_debug_grpc_test.py index e80ae39828a..b8baf3a116e 100644 --- a/tensorflow/python/debug/lib/session_debug_grpc_test.py +++ b/tensorflow/python/debug/lib/session_debug_grpc_test.py @@ -207,8 +207,8 @@ class SessionDebugGrpcTest(session_debug_testlib.SessionDebugTestBase): del feeds, fetch_keys return framework.WatchOptions( debug_ops=["DebugIdentity", "DebugNumericSummary"], - node_name_regex_whitelist=r".*/read", - op_type_regex_whitelist=None, + node_name_regex_allowlist=r".*/read", + op_type_regex_allowlist=None, tolerate_debug_op_creation_failures=True) u = variables.VariableV1(2.1, name="u") diff --git a/tensorflow/python/debug/lib/source_utils.py b/tensorflow/python/debug/lib/source_utils.py index 1e9f7ee82a2..52cbfea6ee2 100644 --- a/tensorflow/python/debug/lib/source_utils.py +++ b/tensorflow/python/debug/lib/source_utils.py @@ -221,15 +221,15 @@ def annotate_source(dump, def list_source_files_against_dump(dump, - path_regex_whitelist=None, - node_name_regex_whitelist=None): + path_regex_allowlist=None, + node_name_regex_allowlist=None): """Generate a list of source files with information regarding ops and tensors. Args: dump: (`DebugDumpDir`) A `DebugDumpDir` object of which the Python graph has been loaded. - path_regex_whitelist: A regular-expression filter for source file path. - node_name_regex_whitelist: A regular-expression filter for node names. + path_regex_allowlist: A regular-expression filter for source file path. + node_name_regex_allowlist: A regular-expression filter for node names. Returns: A list of tuples regarding the Python source files involved in constructing @@ -264,10 +264,11 @@ def list_source_files_against_dump(dump, path_to_first_line = {} tensor_name_to_num_dumps = {} - path_regex = (re.compile(path_regex_whitelist) - if path_regex_whitelist else None) - node_name_regex = (re.compile(node_name_regex_whitelist) - if node_name_regex_whitelist else None) + path_regex = ( + re.compile(path_regex_allowlist) if path_regex_allowlist else None) + node_name_regex = ( + re.compile(node_name_regex_allowlist) + if node_name_regex_allowlist else None) to_skip_file_paths = set() for op in py_graph.get_operations(): diff --git a/tensorflow/python/debug/lib/source_utils_test.py b/tensorflow/python/debug/lib/source_utils_test.py index da4b9b87b7c..d0a4ecdbac4 100644 --- a/tensorflow/python/debug/lib/source_utils_test.py +++ b/tensorflow/python/debug/lib/source_utils_test.py @@ -406,7 +406,7 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase): def testGenerateSourceListWithNodeNameFilter(self): source_list = source_utils.list_source_files_against_dump( - self.dump, node_name_regex_whitelist=r"while/Add.*") + self.dump, node_name_regex_allowlist=r"while/Add.*") # Assert that the file paths are sorted. 
file_paths = [item[0] for item in source_list] @@ -433,8 +433,8 @@ class ListSourceAgainstDumpTest(test_util.TensorFlowTestCase): curr_file_basename = os.path.basename(self.curr_file_path) source_list = source_utils.list_source_files_against_dump( self.dump, - path_regex_whitelist=( - ".*" + curr_file_basename.replace(".", "\\.") + "$")) + path_regex_allowlist=(".*" + curr_file_basename.replace(".", "\\.") + + "$")) self.assertEqual(1, len(source_list)) (file_path, is_tf_py_library, num_nodes, num_tensors, num_dumps, diff --git a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py index 0a0b1eb018a..16b5537dd4a 100644 --- a/tensorflow/python/debug/wrappers/dumping_wrapper_test.py +++ b/tensorflow/python/debug/wrappers/dumping_wrapper_test.py @@ -169,7 +169,7 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): log_usage=False) def testDumpingWithLegacyWatchFnOnFetchesWorks(self): - """Use a watch_fn that returns different whitelists for different runs.""" + """Use a watch_fn that returns different allowlists for different runs.""" def watch_fn(fetches, feeds): del feeds @@ -240,9 +240,9 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): del fetches, feeds return framework.WatchOptions( debug_ops=["DebugIdentity", "DebugNumericSummary"], - node_name_regex_whitelist=r"^v.*", - op_type_regex_whitelist=r".*", - tensor_dtype_regex_whitelist=".*_ref") + node_name_regex_allowlist=r"^v.*", + op_type_regex_allowlist=r".*", + tensor_dtype_regex_allowlist=".*_ref") sess = dumping_wrapper.DumpingDebugWrapperSession( self.sess, @@ -288,14 +288,13 @@ class DumpingDebugWrapperSessionTest(test_util.TensorFlowTestCase): if watch_fn_state["run_counter"] % 2 == 1: # If odd-index run (1-based), watch every ref-type tensor. return framework.WatchOptions( - debug_ops="DebugIdentity", - tensor_dtype_regex_whitelist=".*_ref") + debug_ops="DebugIdentity", tensor_dtype_regex_allowlist=".*_ref") else: # If even-index run, watch nothing. return framework.WatchOptions( debug_ops="DebugIdentity", - node_name_regex_whitelist=r"^$", - op_type_regex_whitelist=r"^$") + node_name_regex_allowlist=r"^$", + op_type_regex_allowlist=r"^$") dumping_hook = hooks.DumpingDebugHook( self.session_root, watch_fn=counting_watch_fn, log_usage=False) diff --git a/tensorflow/python/debug/wrappers/framework.py b/tensorflow/python/debug/wrappers/framework.py index 9b107fe9a2b..4fc1e33d130 100644 --- a/tensorflow/python/debug/wrappers/framework.py +++ b/tensorflow/python/debug/wrappers/framework.py @@ -234,9 +234,9 @@ class OnRunStartResponse(object): action, debug_urls, debug_ops="DebugIdentity", - node_name_regex_whitelist=None, - op_type_regex_whitelist=None, - tensor_dtype_regex_whitelist=None, + node_name_regex_allowlist=None, + op_type_regex_allowlist=None, + tensor_dtype_regex_allowlist=None, tolerate_debug_op_creation_failures=False): """Constructor of `OnRunStartResponse`. @@ -247,10 +247,10 @@ class OnRunStartResponse(object): during the run() call. debug_ops: (`str` or `list` of `str`) Debug op(s) to be used by the debugger. - node_name_regex_whitelist: Regular-expression whitelist for node + node_name_regex_allowlist: Regular-expression allowlist for node name. - op_type_regex_whitelist: Regular-expression whitelist for op type. - tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor + op_type_regex_allowlist: Regular-expression allowlist for op type. 
+ tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor dtype. tolerate_debug_op_creation_failures: Whether debug op creation failures are to be tolerated. @@ -264,9 +264,9 @@ class OnRunStartResponse(object): self.debug_ops = debug_ops - self.node_name_regex_whitelist = node_name_regex_whitelist - self.op_type_regex_whitelist = op_type_regex_whitelist - self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist + self.node_name_regex_allowlist = node_name_regex_allowlist + self.op_type_regex_allowlist = op_type_regex_allowlist + self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist self.tolerate_debug_op_creation_failures = ( tolerate_debug_op_creation_failures) @@ -329,7 +329,7 @@ class BaseDebugWrapperSession(session.SessionInterface): Args: sess: An (unwrapped) TensorFlow session instance. It should be a subtype of `BaseSession` or `tf.MonitoredSession`. - thread_name_filter: Regular-expression filter (whitelist) for name(s) of + thread_name_filter: Regular-expression filter (allowlist) for name(s) of thread(s) on which the wrapper session will be active. This regular expression is used in a start-anchored fashion on the thread name, i.e., by applying the `match` method of the compiled pattern. The default @@ -545,11 +545,10 @@ class BaseDebugWrapperSession(session.SessionInterface): decorated_run_options, run_start_resp.debug_urls, debug_ops=run_start_resp.debug_ops, - node_name_regex_whitelist=( - run_start_resp.node_name_regex_whitelist), - op_type_regex_whitelist=run_start_resp.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=( - run_start_resp.tensor_dtype_regex_whitelist), + node_name_regex_allowlist=(run_start_resp.node_name_regex_allowlist), + op_type_regex_allowlist=run_start_resp.op_type_regex_allowlist, + tensor_dtype_regex_allowlist=( + run_start_resp.tensor_dtype_regex_allowlist), tolerate_debug_op_creation_failures=( run_start_resp.tolerate_debug_op_creation_failures)) @@ -707,9 +706,9 @@ class BaseDebugWrapperSession(session.SessionInterface): run_options, debug_urls, debug_ops="DebugIdentity", - node_name_regex_whitelist=None, - op_type_regex_whitelist=None, - tensor_dtype_regex_whitelist=None, + node_name_regex_allowlist=None, + op_type_regex_allowlist=None, + tensor_dtype_regex_allowlist=None, tolerate_debug_op_creation_failures=False): """Modify a RunOptions object for debug tensor watching. @@ -721,10 +720,10 @@ class BaseDebugWrapperSession(session.SessionInterface): debug_urls: (list of str) debug URLs to be entered in run_options. debug_tensor_watch_opts. debug_ops: (str or list of str) debug op(s) to be used by the debugger. - node_name_regex_whitelist: Regular-expression whitelist for node + node_name_regex_allowlist: Regular-expression allowlist for node name. - op_type_regex_whitelist: Regular-expression whitelist for op type. - tensor_dtype_regex_whitelist: Regular-expression whitelist for tensor + op_type_regex_allowlist: Regular-expression allowlist for op type. + tensor_dtype_regex_allowlist: Regular-expression allowlist for tensor dtype. tolerate_debug_op_creation_failures: Whether debug op creation failures are to be tolerated. 
@@ -736,9 +735,9 @@ class BaseDebugWrapperSession(session.SessionInterface): self._sess.graph, debug_urls=debug_urls, debug_ops=debug_ops, - node_name_regex_whitelist=node_name_regex_whitelist, - op_type_regex_whitelist=op_type_regex_whitelist, - tensor_dtype_regex_whitelist=tensor_dtype_regex_whitelist, + node_name_regex_allowlist=node_name_regex_allowlist, + op_type_regex_allowlist=op_type_regex_allowlist, + tensor_dtype_regex_allowlist=tensor_dtype_regex_allowlist, tolerate_debug_op_creation_failures=tolerate_debug_op_creation_failures, reset_disk_byte_usage=(self._run_call_count == 1 or self._is_disk_usage_reset_each_run())) @@ -821,8 +820,8 @@ class BaseDebugWrapperSession(session.SessionInterface): def close(self): self._sess.close() - # TODO(cais): Add _node_name_regex_whitelist and - # _node_op_type_regex_whitelist. + # TODO(cais): Add _node_name_regex_allowlist and + # _node_op_type_regex_allowlist. def should_stop(self): if hasattr(self._sess, "should_stop"): @@ -838,9 +837,9 @@ class WatchOptions(object): def __init__(self, debug_ops=None, - node_name_regex_whitelist=None, - op_type_regex_whitelist=None, - tensor_dtype_regex_whitelist=None, + node_name_regex_allowlist=None, + op_type_regex_allowlist=None, + tensor_dtype_regex_allowlist=None, tolerate_debug_op_creation_failures=False): """Constructor of WatchOptions: Debug watch options. @@ -848,17 +847,17 @@ class WatchOptions(object): Args: debug_ops: (`str` or `list of str`) Debug ops to be used. - node_name_regex_whitelist: Regular-expression whitelist for node_name, + node_name_regex_allowlist: Regular-expression allowlist for node_name, e.g., `"(weight_[0-9]+|bias_.*)"` - op_type_regex_whitelist: Regular-expression whitelist for the op type of + op_type_regex_allowlist: Regular-expression allowlist for the op type of nodes, e.g., `"(Variable|Add)"`. - If both `node_name_regex_whitelist` and `op_type_regex_whitelist` + If both `node_name_regex_allowlist` and `op_type_regex_allowlist` are set, the two filtering operations will occur in a logical `AND` relation. In other words, a node will be included if and only if it - hits both whitelists. - tensor_dtype_regex_whitelist: Regular-expression whitelist for Tensor + hits both allowlists. + tensor_dtype_regex_allowlist: Regular-expression allowlist for Tensor data type, e.g., `"^int.*"`. - This whitelist operates in logical `AND` relations to the two whitelists + This allowlist operates in logical `AND` relations to the two allowlists above. 
tolerate_debug_op_creation_failures: (`bool`) whether debug op creation failures (e.g., due to dtype incompatibility) are to be tolerated by not @@ -868,19 +867,19 @@ class WatchOptions(object): self.debug_ops = debug_ops else: self.debug_ops = ["DebugIdentity"] - self.node_name_regex_whitelist = node_name_regex_whitelist - self.op_type_regex_whitelist = op_type_regex_whitelist - self.tensor_dtype_regex_whitelist = tensor_dtype_regex_whitelist + self.node_name_regex_allowlist = node_name_regex_allowlist + self.op_type_regex_allowlist = op_type_regex_allowlist + self.tensor_dtype_regex_allowlist = tensor_dtype_regex_allowlist self.tolerate_debug_op_creation_failures = ( tolerate_debug_op_creation_failures) def __repr__(self): - return ("WatchOptions(debug_ops=%r, node_name_regex_whitelist=%r, " - "op_type_regex_whitelist=%r, tensor_dtype_regex_whitelist=%r, " - "tolerate_debug_op_creation_failures=%r)" % ( - self.debug_ops, self.node_name_regex_whitelist, - self.op_type_regex_whitelist, self.tensor_dtype_regex_whitelist, - self.tolerate_debug_op_creation_failures)) + return ("WatchOptions(debug_ops=%r, node_name_regex_allowlist=%r, " + "op_type_regex_allowlist=%r, tensor_dtype_regex_allowlist=%r, " + "tolerate_debug_op_creation_failures=%r)" % + (self.debug_ops, self.node_name_regex_allowlist, + self.op_type_regex_allowlist, self.tensor_dtype_regex_allowlist, + self.tolerate_debug_op_creation_failures)) class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession): @@ -952,14 +951,14 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession): OnRunStartAction.DEBUG_RUN, debug_urls, debug_ops=watch_opts.debug_ops, - node_name_regex_whitelist=watch_opts.node_name_regex_whitelist, - op_type_regex_whitelist=watch_opts.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=watch_opts.tensor_dtype_regex_whitelist, + node_name_regex_allowlist=watch_opts.node_name_regex_allowlist, + op_type_regex_allowlist=watch_opts.op_type_regex_allowlist, + tensor_dtype_regex_allowlist=watch_opts.tensor_dtype_regex_allowlist, tolerate_debug_op_creation_failures=( watch_opts.tolerate_debug_op_creation_failures)) def _prepare_run_watch_config(self, fetches, feed_dict): - """Get the debug_urls, and node/op whitelists for the current run() call. + """Get the debug_urls, and node/op allowlists for the current run() call. Args: fetches: Same as the `fetches` argument to `Session.run()`. @@ -969,7 +968,7 @@ class NonInteractiveDebugWrapperSession(BaseDebugWrapperSession): debug_urls: (str or list of str) Debug URLs for the current run() call. Currently, the list consists of only one URL that is a file:// URL. watch_options: (WatchOptions) The return value of a watch_fn, containing - options including debug_ops, and whitelists. + options including debug_ops, and allowlists. 
""" debug_urls = self.prepare_run_debug_urls(fetches, feed_dict) diff --git a/tensorflow/python/debug/wrappers/hooks.py b/tensorflow/python/debug/wrappers/hooks.py index 4c958be257c..2106fcc7492 100644 --- a/tensorflow/python/debug/wrappers/hooks.py +++ b/tensorflow/python/debug/wrappers/hooks.py @@ -124,12 +124,12 @@ class LocalCLIDebugHook(session_run_hook.SessionRunHook): run_args.options, on_run_start_response.debug_urls, debug_ops=on_run_start_response.debug_ops, - node_name_regex_whitelist=( - on_run_start_response.node_name_regex_whitelist), - op_type_regex_whitelist=( - on_run_start_response.op_type_regex_whitelist), - tensor_dtype_regex_whitelist=( - on_run_start_response.tensor_dtype_regex_whitelist), + node_name_regex_allowlist=( + on_run_start_response.node_name_regex_allowlist), + op_type_regex_allowlist=( + on_run_start_response.op_type_regex_allowlist), + tensor_dtype_regex_allowlist=( + on_run_start_response.tensor_dtype_regex_allowlist), tolerate_debug_op_creation_failures=( on_run_start_response.tolerate_debug_op_creation_failures)) # pylint: enable=protected-access @@ -205,9 +205,9 @@ class DumpingDebugHook(session_run_hook.SessionRunHook): run_context.session.graph, debug_urls=debug_urls, debug_ops=watch_options.debug_ops, - node_name_regex_whitelist=watch_options.node_name_regex_whitelist, - op_type_regex_whitelist=watch_options.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist, + node_name_regex_allowlist=watch_options.node_name_regex_allowlist, + op_type_regex_allowlist=watch_options.op_type_regex_allowlist, + tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist, tolerate_debug_op_creation_failures=( watch_options.tolerate_debug_op_creation_failures), reset_disk_byte_usage=reset_disk_byte_usage) @@ -292,9 +292,9 @@ class GrpcDebugHook(session_run_hook.SessionRunHook): debug_urls=self._grpc_debug_wrapper_session.prepare_run_debug_urls( fetches, feed_dict), debug_ops=watch_options.debug_ops, - node_name_regex_whitelist=watch_options.node_name_regex_whitelist, - op_type_regex_whitelist=watch_options.op_type_regex_whitelist, - tensor_dtype_regex_whitelist=watch_options.tensor_dtype_regex_whitelist, + node_name_regex_allowlist=watch_options.node_name_regex_allowlist, + op_type_regex_allowlist=watch_options.op_type_regex_allowlist, + tensor_dtype_regex_allowlist=watch_options.tensor_dtype_regex_allowlist, tolerate_debug_op_creation_failures=( watch_options.tolerate_debug_op_creation_failures)) diff --git a/tensorflow/python/debug/wrappers/local_cli_wrapper.py b/tensorflow/python/debug/wrappers/local_cli_wrapper.py index 0d8c71396f0..4069bdf1f3f 100644 --- a/tensorflow/python/debug/wrappers/local_cli_wrapper.py +++ b/tensorflow/python/debug/wrappers/local_cli_wrapper.py @@ -552,9 +552,9 @@ class LocalCLIDebugWrapperSession(framework.BaseDebugWrapperSession): run_start_response = framework.OnRunStartResponse( action, debug_urls, - node_name_regex_whitelist=parsed.node_name_filter, - op_type_regex_whitelist=parsed.op_type_filter, - tensor_dtype_regex_whitelist=parsed.tensor_dtype_filter) + node_name_regex_allowlist=parsed.node_name_filter, + op_type_regex_allowlist=parsed.op_type_filter, + tensor_dtype_regex_allowlist=parsed.tensor_dtype_filter) if parsed.till_filter_pass: # For the run-till-filter-pass (run -f) mode, use the DEBUG_RUN diff --git a/tensorflow/python/distribute/mirrored_run.py b/tensorflow/python/distribute/mirrored_run.py index c0438d4fd12..05018450121 100644 --- 
a/tensorflow/python/distribute/mirrored_run.py +++ b/tensorflow/python/distribute/mirrored_run.py @@ -88,7 +88,7 @@ def call_for_each_replica(strategy, fn, args=None, kwargs=None): else: # When a tf.function is wrapped to trigger _call_for_each_replica (see # the other branch above), AutoGraph stops conversion at - # _call_for_each_replica itself (TF library functions are whitelisted). + # _call_for_each_replica itself (TF library functions are allowlisted). # This makes sure that the Python function that originally passed to # the tf.function is still converted. fn = autograph.tf_convert(fn, autograph_ctx.control_status_ctx()) diff --git a/tensorflow/python/eager/function.py b/tensorflow/python/eager/function.py index dc75ca13645..3ce6d300bc4 100644 --- a/tensorflow/python/eager/function.py +++ b/tensorflow/python/eager/function.py @@ -237,7 +237,7 @@ def _parse_func_attrs(attributes): A dict of attributes where the key is the name of attribute and the value is the AttrValue proto. Raises: - ValueError: If the kwargs contains unwhitelisted name or unsupported value + ValueError: If the kwargs contains unallowlisted name or unsupported value types. """ attrs = {} @@ -3625,9 +3625,9 @@ def defun_with_attributes(func=None, input_signature: same as defun()'s input_signature. attributes: A dictionary of arguments which will be added to function def as attributes. Currently only support primitive types as value, and only - whitelisted attribute name is allowed. Unwhitelisted attribute name or + allowlisted attribute name is allowed. Unallowlisted attribute name or unsupported value will result into ValueError. `func_name` is also one of - the whitelisted argument which is a python string, and sets the name for + the allowlisted argument which is a python string, and sets the name for this `ConcreteFunction` in the graph. autograph: same as defun()'s autograph. experimental_autograph_options: same as defun()'s diff --git a/tensorflow/python/framework/auto_control_deps.py b/tensorflow/python/framework/auto_control_deps.py index 4b47735e0bf..7bc92936cb0 100644 --- a/tensorflow/python/framework/auto_control_deps.py +++ b/tensorflow/python/framework/auto_control_deps.py @@ -108,9 +108,9 @@ _ALL_BLACKLISTED_OPS = ( set(ASYNC_STATEFUL_OPS) | set(LEGACY_RANDOM_OPS) | set(_ORDER_INSENSITIVE_STATEFUL_OPS)) -# Op types that are marked as stateless, but should be whitelisted to add auto +# Op types that are marked as stateless, but should be allowlisted to add auto # control dependencies. -_WHITELIST_STATELESS_OPS = [ +_ALLOWLIST_STATELESS_OPS = [ # As TPU collective ops are blocking, if there are more than one collective # op in the function, we need to make sure different collectives ops are # scheduled in certain orders. 
Otherwise if at the same time all the @@ -125,7 +125,7 @@ _WHITELIST_STATELESS_OPS = [ def op_is_stateful(op): # pylint: disable=protected-access return (op._is_stateful and op.type not in _ALL_BLACKLISTED_OPS) or ( - op.type in _WHITELIST_STATELESS_OPS) + op.type in _ALLOWLIST_STATELESS_OPS) class ResourceType(enum.Enum): diff --git a/tensorflow/python/framework/convert_to_constants.py b/tensorflow/python/framework/convert_to_constants.py index 4c3cbb06bf1..555004e0836 100644 --- a/tensorflow/python/framework/convert_to_constants.py +++ b/tensorflow/python/framework/convert_to_constants.py @@ -710,12 +710,12 @@ class _ConverterData(object): def __init__(self, graph_def, - variable_names_whitelist=None, + variable_names_allowlist=None, variable_names_blacklist=None): self._graph_def = graph_def self._tensor_data = {} self._build_node_defs_list() - self._variable_names_whitelist = variable_names_whitelist + self._variable_names_allowlist = variable_names_allowlist self._variable_names_blacklist = variable_names_blacklist @property @@ -740,8 +740,8 @@ class _ConverterData(object): def _should_convert(self, name): """Checks whether to convert the given variable name to a constant.""" - return (self._variable_names_whitelist is None or - name in self._variable_names_whitelist) and ( + return (self._variable_names_allowlist is None or + name in self._variable_names_allowlist) and ( self._variable_names_blacklist is None or name not in self._variable_names_blacklist) @@ -776,7 +776,7 @@ class _FunctionConverterData(_ConverterData): func, lower_control_flow, aggressive_inlining, - variable_names_whitelist=None, + variable_names_allowlist=None, variable_names_blacklist=None): """Creates the conversion data for the given function. @@ -787,7 +787,7 @@ class _FunctionConverterData(_ConverterData): aggressive_inlining: Boolean indicating whether or not to to aggressive function inlining (might be unsafe if function has stateful ops, not properly connected to control outputs). - variable_names_whitelist: The set of variable names to convert (by + variable_names_allowlist: The set of variable names to convert (by default, all variables are converted). variable_names_blacklist: The set of variable names to omit converting to constants. @@ -799,7 +799,7 @@ class _FunctionConverterData(_ConverterData): aggressive_inlining) super(_FunctionConverterData, self).__init__( graph_def, - variable_names_whitelist=variable_names_whitelist, + variable_names_allowlist=variable_names_allowlist, variable_names_blacklist=variable_names_blacklist) self._build_tensor_data() @@ -849,12 +849,12 @@ class _SessionConverterData(_ConverterData): session, graph_def, output_node_names, - variable_names_whitelist=None, + variable_names_allowlist=None, variable_names_blacklist=None): graph_def = graph_util.extract_sub_graph(graph_def, output_node_names) super(_SessionConverterData, self).__init__( graph_def, - variable_names_whitelist=variable_names_whitelist, + variable_names_allowlist=variable_names_allowlist, variable_names_blacklist=variable_names_blacklist) nodes_to_convert = [] @@ -1114,7 +1114,7 @@ def convert_variables_to_constants_from_session_graph( session, graph_def, output_node_names, - variable_names_whitelist=None, + variable_names_allowlist=None, variable_names_blacklist=None): """Replaces all the variables in a graph with constants of the same values. @@ -1129,7 +1129,7 @@ def convert_variables_to_constants_from_session_graph( session: Active TensorFlow session containing the variables. 
graph_def: A GraphDef to convert. output_node_names: List of name strings for the result nodes of the graph. - variable_names_whitelist: The set of variable names to convert (by default, + variable_names_allowlist: The set of variable names to convert (by default, all variables are converted). variable_names_blacklist: The set of variable names to omit converting to constants. @@ -1142,6 +1142,6 @@ def convert_variables_to_constants_from_session_graph( session=session, graph_def=graph_def, output_node_names=output_node_names, - variable_names_whitelist=variable_names_whitelist, + variable_names_allowlist=variable_names_allowlist, variable_names_blacklist=variable_names_blacklist)) return graph_def diff --git a/tensorflow/python/framework/func_graph.py b/tensorflow/python/framework/func_graph.py index e8e8fcbf081..55508c4803b 100644 --- a/tensorflow/python/framework/func_graph.py +++ b/tensorflow/python/framework/func_graph.py @@ -49,7 +49,7 @@ from tensorflow.python.util import object_identity from tensorflow.python.util import tf_contextlib from tensorflow.python.util import tf_decorator -WHITELIST_COLLECTIONS = [ +ALLOWLIST_COLLECTIONS = [ ops.GraphKeys.GLOBAL_VARIABLES, ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.TRAINABLE_VARIABLES, @@ -172,9 +172,9 @@ class FuncGraph(ops.Graph): name: the name of the function. collections: a dictionary of collections this FuncGraph should start with. If not specified (None), the FuncGraph will read (but not write - to) the outer graph's collections that are not whitelisted, and both - read and write to the outer graph's collections that are whitelisted. - The current whitelisted collections are the global variables, the + to) the outer graph's collections that are not allowlisted, and both + read and write to the outer graph's collections that are allowlisted. + The current allowlisted collections are the global variables, the local variables, and the trainable variables. Defaults to None. capture_by_value: An optional boolean. If True, the func graph will @@ -241,10 +241,10 @@ class FuncGraph(ops.Graph): if collections is None: for collection_name in graph.get_all_collection_keys(): - if collection_name not in WHITELIST_COLLECTIONS: + if collection_name not in ALLOWLIST_COLLECTIONS: self._collections[collection_name] = graph.get_collection( collection_name) - for collection_name in WHITELIST_COLLECTIONS: + for collection_name in ALLOWLIST_COLLECTIONS: self._collections[collection_name] = graph.get_collection_ref( collection_name) else: @@ -842,9 +842,9 @@ def func_graph_from_py_func(name, set, returning an Operation triggers an error. collections: a dictionary of collections this FuncGraph should start with. If not specified (None), the FuncGraph will read (but not write to) - the outer graph's collections that are not whitelisted, and both - read and write to the outer graph's collections that are whitelisted. - The current whitelisted collections are the global variables, the + the outer graph's collections that are not allowlisted, and both + read and write to the outer graph's collections that are allowlisted. + The current allowlisted collections are the global variables, the local variables, and the trainable variables. Defaults to None. capture_by_value: An optional boolean. 
If True, the func graph will capture diff --git a/tensorflow/python/framework/function.py b/tensorflow/python/framework/function.py index 3f36918039c..c2bffbeecc7 100644 --- a/tensorflow/python/framework/function.py +++ b/tensorflow/python/framework/function.py @@ -234,7 +234,7 @@ class _DefinedFunction(object): out_names=None, shape_func=None, capture_by_value=False, - whitelisted_stateful_ops=None, + allowlisted_stateful_ops=None, capture_resource_var_by_value=True, **kwargs): """Creates _DefinedFunction. @@ -256,7 +256,7 @@ class _DefinedFunction(object): output shapes. capture_by_value: Boolean (defaults to False). If True, captured values will be copied into the function body. - whitelisted_stateful_ops: A set of ops that if stateful we ignore and + allowlisted_stateful_ops: A set of ops that if stateful we ignore and copy into the function body, when `capture_by_value` is True. capture_resource_var_by_value: Boolean (defaults to True). If False, captured resource variable returns the handle instead of value. @@ -275,9 +275,9 @@ class _DefinedFunction(object): self._out_names = out_names self._shape_func = shape_func self._capture_by_value = capture_by_value - self._whitelisted_stateful_ops = whitelisted_stateful_ops - if self._whitelisted_stateful_ops is None: - self._whitelisted_stateful_ops = set() + self._allowlisted_stateful_ops = allowlisted_stateful_ops + if self._allowlisted_stateful_ops is None: + self._allowlisted_stateful_ops = set() self._capture_resource_var_by_value = capture_resource_var_by_value self._extra_kwargs = kwargs # Constructed only when C API is disabled, lazily @@ -403,7 +403,7 @@ class _DefinedFunction(object): self._capture_by_value, self._caller_device, collections_ref=collections_ref, - whitelisted_stateful_ops=self._whitelisted_stateful_ops, + allowlisted_stateful_ops=self._allowlisted_stateful_ops, capture_resource_var_by_value=self._capture_resource_var_by_value) self._extra_inputs = temp_graph.extra_inputs @@ -690,11 +690,11 @@ class _FuncGraph(ops.Graph): function argument and the caller passes in the captured tensor. """ - def __init__(self, name, capture_by_value, whitelisted_stateful_ops, + def __init__(self, name, capture_by_value, allowlisted_stateful_ops, capture_resource_var_by_value, *args, **kwargs): super(_FuncGraph, self).__init__(*args, **kwargs) self._capture_by_value = capture_by_value - self._whitelisted_stateful_ops = whitelisted_stateful_ops + self._allowlisted_stateful_ops = allowlisted_stateful_ops self._capture_resource_var_by_value = capture_resource_var_by_value self._building_function = True self._outer_graph = ops.get_default_graph() @@ -879,7 +879,7 @@ class _FuncGraph(ops.Graph): def _add_op_and_parents(self, op): # pylint: disable=protected-access op_def = graph_to_function_def._get_op_def(op) - if op._is_stateful and op not in self._whitelisted_stateful_ops: + if op._is_stateful and op not in self._allowlisted_stateful_ops: raise ValueError("Cannot capture a stateful node (name:%s, type:%s) " "by value." % (op.name, op.type)) elif op.type in ("Placeholder", "PlaceholderV2"): @@ -912,7 +912,7 @@ def func_graph_from_py_func(func, container=None, collections_ref=None, arg_shapes=None, - whitelisted_stateful_ops=None, + allowlisted_stateful_ops=None, capture_resource_var_by_value=True): """Returns a _FuncGraph generated from `func`. @@ -931,7 +931,7 @@ def func_graph_from_py_func(func, collections_ref: A reference to a collections dict the _FuncGraph should use internally. 
arg_shapes: A sequence of the function's argument shapes. - whitelisted_stateful_ops: A set of ops that if stateful we ignore and + allowlisted_stateful_ops: A set of ops that if stateful we ignore and re-create. capture_resource_var_by_value: Boolean (defaults to True). If False, captured resource variable returns the handle instead of value. @@ -944,7 +944,7 @@ def func_graph_from_py_func(func, """ if not name: name = function_utils.get_func_name(func) - func_graph = _FuncGraph(name, capture_by_value, whitelisted_stateful_ops, + func_graph = _FuncGraph(name, capture_by_value, allowlisted_stateful_ops, capture_resource_var_by_value) with func_graph.as_default(), ops.device(device): diff --git a/tensorflow/python/framework/function_test.py b/tensorflow/python/framework/function_test.py index 26ae88e58c7..16b2c7c5048 100644 --- a/tensorflow/python/framework/function_test.py +++ b/tensorflow/python/framework/function_test.py @@ -1043,7 +1043,7 @@ class FunctionTest(test.TestCase): self.assertFalse(all(val4 == val2)) @test_util.run_v1_only("currently failing on v2") - def testStatefulFunctionWithWhitelisting(self): + def testStatefulFunctionWithAllowlisting(self): t = random_ops.random_uniform([100], maxval=10, dtype=dtypes.int32) @function.Defun(capture_by_value=True) @@ -1054,8 +1054,8 @@ class FunctionTest(test.TestCase): with self.assertRaisesRegex(ValueError, "Cannot capture a stateful node"): res = StatefulFn() - # This time we whitelist this op, so that its recreated. - @function.Defun(capture_by_value=True, whitelisted_stateful_ops=set([t.op])) + # This time we allowlist this op, so that its recreated. + @function.Defun(capture_by_value=True, allowlisted_stateful_ops=set([t.op])) def StatefulFn2(): return t + constant_op.constant(3, dtype=dtypes.int32) diff --git a/tensorflow/python/framework/graph_util_impl.py b/tensorflow/python/framework/graph_util_impl.py index 3cc28d0a707..753584813f9 100644 --- a/tensorflow/python/framework/graph_util_impl.py +++ b/tensorflow/python/framework/graph_util_impl.py @@ -276,7 +276,7 @@ def convert_variables_to_constants(sess, session=sess, graph_def=input_graph_def, output_node_names=output_node_names, - variable_names_whitelist=variable_names_whitelist, + variable_names_allowlist=variable_names_whitelist, variable_names_blacklist=variable_names_blacklist) # The previous code logic generated an empty versions field, we clear it here # to maintain backwards compatibility. diff --git a/tensorflow/python/framework/importer_test.py b/tensorflow/python/framework/importer_test.py index 9d64311b4b1..de295955c78 100644 --- a/tensorflow/python/framework/importer_test.py +++ b/tensorflow/python/framework/importer_test.py @@ -472,7 +472,7 @@ class ImportGraphDefTest(test.TestCase): node { name: 'B' op: 'FloatInput' input: 'A:0' } """)) - def testShapeWhitelistViolation(self): + def testShapeAllowlistViolation(self): # L2 loss produces a scalar shape, but the graph # has the wrong shape, so raise an error. 
with ops.Graph().as_default(): diff --git a/tensorflow/python/framework/python_op_gen.cc b/tensorflow/python/framework/python_op_gen.cc index 5b9665cf5fc..8a3c940a566 100644 --- a/tensorflow/python/framework/python_op_gen.cc +++ b/tensorflow/python/framework/python_op_gen.cc @@ -351,7 +351,7 @@ string GenEagerPythonOp::Code() { } std::unordered_map type_annotations; - // Only populate map for whitelisted ops + // Only populate map for allowlisted ops if (add_type_annotations_) { type_annotations = GetTypeAnnotations(); } diff --git a/tensorflow/python/framework/python_op_gen_main.cc b/tensorflow/python/framework/python_op_gen_main.cc index 08f546e8dd8..1cf6ad6e0e4 100644 --- a/tensorflow/python/framework/python_op_gen_main.cc +++ b/tensorflow/python/framework/python_op_gen_main.cc @@ -108,7 +108,7 @@ string InferSourceFileName(const char* argv_zero) { void PrintAllPythonOps(const std::vector& op_list, const std::vector& api_def_dirs, const string& source_file_name, - bool op_list_is_whitelist, + bool op_list_is_allowlist, const std::unordered_set type_annotate_ops) { OpList ops; OpRegistry::Global()->Export(false, &ops); @@ -126,11 +126,11 @@ void PrintAllPythonOps(const std::vector& op_list, api_def_map.UpdateDocs(); } - if (op_list_is_whitelist) { - std::unordered_set whitelist(op_list.begin(), op_list.end()); + if (op_list_is_allowlist) { + std::unordered_set allowlist(op_list.begin(), op_list.end()); OpList pruned_ops; for (const auto& op_def : ops.op()) { - if (whitelist.find(op_def.name()) != whitelist.end()) { + if (allowlist.find(op_def.name()) != allowlist.end()) { *pruned_ops.mutable_op()->Add() = op_def; } } @@ -165,13 +165,13 @@ int main(int argc, char* argv[]) { if (argc == 2) { tensorflow::PrintAllPythonOps({}, api_def_dirs, source_file_name, - false /* op_list_is_whitelist */, + false /* op_list_is_allowlist */, type_annotate_ops); } else if (argc == 3) { std::vector hidden_ops; TF_CHECK_OK(tensorflow::ParseOpListCommandLine(argv[2], &hidden_ops)); tensorflow::PrintAllPythonOps(hidden_ops, api_def_dirs, source_file_name, - false /* op_list_is_whitelist */, + false /* op_list_is_allowlist */, type_annotate_ops); } else if (argc == 4) { std::vector op_list; diff --git a/tensorflow/python/grappler/auto_mixed_precision_test.py b/tensorflow/python/grappler/auto_mixed_precision_test.py index f2b0985b348..539c2bca9f3 100644 --- a/tensorflow/python/grappler/auto_mixed_precision_test.py +++ b/tensorflow/python/grappler/auto_mixed_precision_test.py @@ -201,7 +201,7 @@ def _recurrent_lstm(c, h): def _make_node_with_color(color, input_tensor, name=None): """Returns a node representative of the specified list type.""" color = color.lower() - if color == 'w': # White node + if color == 'w': # Allow node weights = _weight(input_tensor.get_shape().as_list()) return math_ops.matmul(input_tensor, weights, name=name) if color == 'g': # Gray node @@ -371,7 +371,7 @@ class AutoMixedPrecisionTest(test.TestCase, parameterized.TestCase): The loop has different node colors in different sections of the graph. The arguments must be strings where each character represents the color of a - node in that section of the graph: w = white, g = gray, c = clear, + node in that section of the graph: w = allow, g = gray, c = clear, b = black. CAPITALIZED characters indicate that the node is expected to be changed to DT_HALF during graph optimization. 
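A minimal, hypothetical sketch of the color-string convention described above (purely illustrative; not part of this patch):

  # Each character selects the node built by _make_node_with_color for one
  # section of the test graph: w = allow (MatMul), g = gray, c = clear,
  # b = black. An upper-case character marks a node that the
  # auto_mixed_precision rewrite is expected to convert to DT_HALF.
  spec = 'Wgbc'
  expected_half = [ch.isupper() for ch in spec]  # -> [True, False, False, False]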
diff --git a/tensorflow/python/keras/engine/training_utils.py b/tensorflow/python/keras/engine/training_utils.py index 4a351668444..3e0735ceec4 100644 --- a/tensorflow/python/keras/engine/training_utils.py +++ b/tensorflow/python/keras/engine/training_utils.py @@ -1594,7 +1594,7 @@ def assert_not_batched(dataset): if isinstance(dataset, dataset_ops.DatasetV1Adapter): return assert_not_batched(dataset._dataset) else: - whitelisted_types = [ + allowed_types = [ dataset_ops._OptionsDataset, dataset_ops.ConcatenateDataset, dataset_ops.CacheDataset, @@ -1615,7 +1615,7 @@ def assert_not_batched(dataset): readers.TextLineDatasetV2, readers.TFRecordDatasetV2, ] - for ty in whitelisted_types: + for ty in allowed_types: if isinstance(dataset, ty): for input_dataset in dataset._inputs(): assert_not_batched(input_dataset) @@ -1649,7 +1649,7 @@ def assert_not_shuffled(dataset): if isinstance(dataset, dataset_ops.DatasetV1Adapter): return assert_not_shuffled(dataset._dataset) else: - whitelisted_types = [ + allowed_types = [ dataset_ops._OptionsDataset, dataset_ops.BatchDataset, dataset_ops.ConcatenateDataset, @@ -1672,7 +1672,7 @@ def assert_not_shuffled(dataset): readers.TextLineDatasetV2, readers.TFRecordDatasetV2, ] - for ty in whitelisted_types: + for ty in allowed_types: if isinstance(dataset, ty): for input_dataset in dataset._inputs(): assert_not_shuffled(input_dataset) diff --git a/tensorflow/python/keras/engine/training_v1.py b/tensorflow/python/keras/engine/training_v1.py index f30df7b2790..bf518e1e702 100644 --- a/tensorflow/python/keras/engine/training_v1.py +++ b/tensorflow/python/keras/engine/training_v1.py @@ -2858,7 +2858,7 @@ class DistributedCallbackModel(Model): orig_model_weights) def __getattr__(self, item): - # Whitelisted attributes of the model that can be accessed by the user + # Allowed attributes of the model that can be accessed by the user # during a callback. if item not in ('_setattr_tracking', '_layers'): logging.warning('You are accessing attribute ' + item + ' of the ' diff --git a/tensorflow/python/keras/layers/wrappers_test.py b/tensorflow/python/keras/layers/wrappers_test.py index 75d951a4a7a..efff254c688 100644 --- a/tensorflow/python/keras/layers/wrappers_test.py +++ b/tensorflow/python/keras/layers/wrappers_test.py @@ -333,7 +333,7 @@ class TimeDistributedTest(keras_parameterized.TestCase): keras.layers.RNN(keras.layers.SimpleRNNCell(10), stateful=True)) self.assertFalse(td2._always_use_reshape) - # Custom layers are not whitelisted for the fast reshape implementation. + # Custom layers are not allowlisted for the fast reshape implementation. td3 = keras.layers.TimeDistributed(NoReshapeLayer()) self.assertFalse(td3._always_use_reshape) diff --git a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py index 05f4b1d17f3..4479c378638 100644 --- a/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py +++ b/tensorflow/python/keras/optimizer_v2/optimizer_v2_test.py @@ -898,7 +898,7 @@ class OptimizerWithFunctionTest(test.TestCase): _NUM_LEARNERS = 50 APPLY_SCOPE = 'debug_apply' -WHITELIST = [ +ALLOWLIST = [ # optimizer_v2._deduplicate_indexed_slices contains an indexed slice: # array_ops.shape(unique_indices)[0] # which winds up expanding to [0:1:1] thereby creating three constants @@ -1025,8 +1025,8 @@ def identify_redundant_ops(graph): # Certain ops are simply not worth eliminating, and are instead simply # ignored. 
name, op_type = op_defs[0].name, op_defs[0].type - if any(whitelisted_scope in name and op_type == whitelisted_type - for whitelisted_scope, whitelisted_type in WHITELIST): + if any(allowlisted_scope in name and op_type == allowlisted_type + for allowlisted_scope, allowlisted_type in ALLOWLIST): continue num_duplicates += len(op_defs) diff --git a/tensorflow/python/keras/preprocessing/dataset_utils.py b/tensorflow/python/keras/preprocessing/dataset_utils.py index bc65c7b9b99..1c9d283c2f1 100644 --- a/tensorflow/python/keras/preprocessing/dataset_utils.py +++ b/tensorflow/python/keras/preprocessing/dataset_utils.py @@ -45,7 +45,7 @@ def index_directory(directory, valid files found in the directory. Labels should be sorted according to the alphanumeric order of the image file paths (obtained via `os.walk(directory)` in Python). - formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt"). + formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt"). class_names: Only valid if "labels" is "inferred". This is the explict list of class names (must match names of subdirectories). Used to control the order of the classes @@ -136,7 +136,7 @@ def index_subdirectory(directory, class_indices, follow_links, formats): class_indices: dict mapping class names to their index. follow_links: boolean, whether to recursively follow subdirectories (if False, we only list top-level images in `directory`). - formats: Whitelist of file extensions to index (e.g. ".jpg", ".txt"). + formats: Allowlist of file extensions to index (e.g. ".jpg", ".txt"). Returns: tuple `(filenames, labels)`. `filenames` is a list of relative file diff --git a/tensorflow/python/keras/preprocessing/image_dataset.py b/tensorflow/python/keras/preprocessing/image_dataset.py index 7c7df083559..892fbe59709 100644 --- a/tensorflow/python/keras/preprocessing/image_dataset.py +++ b/tensorflow/python/keras/preprocessing/image_dataset.py @@ -28,7 +28,7 @@ from tensorflow.python.ops import io_ops from tensorflow.python.util.tf_export import keras_export -WHITELIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png') +ALLOWLIST_FORMATS = ('.bmp', '.gif', '.jpeg', '.jpg', '.png') @keras_export('keras.preprocessing.image_dataset_from_directory', v1=[]) @@ -175,7 +175,7 @@ def image_dataset_from_directory(directory, image_paths, labels, class_names = dataset_utils.index_directory( directory, labels, - formats=WHITELIST_FORMATS, + formats=ALLOWLIST_FORMATS, class_names=class_names, shuffle=shuffle, seed=seed, diff --git a/tensorflow/python/ops/while_v2.py b/tensorflow/python/ops/while_v2.py index 8e92c44e707..30d4c6d235a 100644 --- a/tensorflow/python/ops/while_v2.py +++ b/tensorflow/python/ops/while_v2.py @@ -865,7 +865,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): This only allows capturing tensors in the forward graph. A ValueError is raised if an attempt is made to capture a tensor not in the forward graph. To manually capture capture a tensor that is not in the forward graph, call - `capture` with `whitelisted=True`. + `capture` with `allowlisted=True`. Note: The `captures` dict does not contain the forward tensor since it is not directly captured. It contains the accumulator corresponding to this forward @@ -968,16 +968,16 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): op_def=op_def, compute_device=compute_device) - def capture(self, tensor, name=None, whitelisted=False): + def capture(self, tensor, name=None, allowlisted=False): """Selectively captures external tensors. 
- If `whitelisted` is False only allows capturing tensors in the + If `allowlisted` is False only allows capturing tensors in the `_forward_graph`. Args: tensor: Tensor. May be from this FuncGraph or a different graph. name: Optional name if a placeholder is created. - whitelisted: If False (default), only allows capturing tensors from the + allowlisted: If False (default), only allows capturing tensors from the forward graph. Returns: @@ -985,9 +985,9 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): Raises: ValueError: If attempting to capture an external tensor not in the forward - graph with `whitelisted` set to False. + graph with `allowlisted` set to False. """ - if not whitelisted and (isinstance(tensor, ops.EagerTensor) or + if not allowlisted and (isinstance(tensor, ops.EagerTensor) or (tensor.graph is not self and tensor.graph != self._forward_graph)): with self._forward_cond_graph.as_default(): @@ -1136,7 +1136,7 @@ class _WhileBodyGradFuncGraph(util.WhileBodyFuncGraph): "Resource tensors must be loop invariants %s." % tensor_in_outer_graph) self._indirect_captures[ops.tensor_id(tensor)] = self.capture( - tensor_in_outer_graph, whitelisted=True) + tensor_in_outer_graph, allowlisted=True) return self._indirect_captures[ops.tensor_id(tensor)] diff --git a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py index 70e63133bdb..e8432618a21 100644 --- a/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py +++ b/tensorflow/python/ops/while_v2_indexed_slices_rewriter.py @@ -143,10 +143,10 @@ def _rewrite_input_as_indexed_slices(body_grad_graph, grad_output_slices, # computation. with body_grad_graph.as_default(): input_slices = ops.IndexedSlices( - values=body_grad_graph.capture(init_slices.values, whitelisted=True), - indices=body_grad_graph.capture(init_slices.indices, whitelisted=True), - dense_shape=body_grad_graph.capture(init_slices.dense_shape, - whitelisted=True)) + values=body_grad_graph.capture(init_slices.values, allowlisted=True), + indices=body_grad_graph.capture(init_slices.indices, allowlisted=True), + dense_shape=body_grad_graph.capture( + init_slices.dense_shape, allowlisted=True)) # Remove the captured tensors from the function inputs. We'll add them back # at the correct index in _update_indexed_slices_param. diff --git a/tensorflow/python/tools/selective_registration_header_lib.py b/tensorflow/python/tools/selective_registration_header_lib.py index 25f538b4f67..ff2b3dba318 100644 --- a/tensorflow/python/tools/selective_registration_header_lib.py +++ b/tensorflow/python/tools/selective_registration_header_lib.py @@ -36,7 +36,7 @@ from tensorflow.python.platform import tf_logging # corresponding kernel; nodes without a corresponding kernel (perhaps due to # attr types) generate a warning but are otherwise ignored. Ops in this set are # registered even if there's no corresponding kernel. -OPS_WITHOUT_KERNEL_WHITELIST = frozenset([ +OPS_WITHOUT_KERNEL_ALLOWLIST = frozenset([ # AccumulateNV2 is rewritten away by AccumulateNV2RemovePass; see # core/common_runtime/accumulate_n_optimizer.cc. 
'AccumulateNV2' @@ -67,7 +67,7 @@ def _get_ops_from_graphdef(graph_def): kernel_class = _pywrap_kernel_registry.TryFindKernelClass( node_def.SerializeToString()) op = str(node_def.op) - if kernel_class or op in OPS_WITHOUT_KERNEL_WHITELIST: + if kernel_class or op in OPS_WITHOUT_KERNEL_ALLOWLIST: op_and_kernel = (op, str(kernel_class.decode('utf-8')) if kernel_class else None) ops.add(op_and_kernel) diff --git a/tensorflow/python/tpu/feature_column_test.py b/tensorflow/python/tpu/feature_column_test.py index 0b4e84a6212..74cfe27f006 100644 --- a/tensorflow/python/tpu/feature_column_test.py +++ b/tensorflow/python/tpu/feature_column_test.py @@ -68,7 +68,7 @@ class EmbeddingColumnTest(test.TestCase): tpu_fc.embedding_column(categorical_column, dimension=embedding_dimension) def test_custom_column(self): - # This column is not in any whitelist but should succeed because + # This column is not in any allowlist but should succeed because # it inherits from V2 CategoricalColumn. categorical_column = fc_lib.categorical_column_with_identity( key='aaa', num_buckets=10) diff --git a/tensorflow/python/training/experimental/mixed_precision.py b/tensorflow/python/training/experimental/mixed_precision.py index 38377dd0600..c41ec38ccef 100644 --- a/tensorflow/python/training/experimental/mixed_precision.py +++ b/tensorflow/python/training/experimental/mixed_precision.py @@ -122,7 +122,7 @@ def enable_mixed_precision_graph_rewrite(opt, loss_scale='dynamic'): * `ClearList`: Ops that do not have numerically significant adverse effects. E.g. `ArgMax` and `Floor`. - * `WhiteList`: Ops that are considered numerically safe for execution in + * `AllowList`: Ops that are considered numerically safe for execution in float16, and thus are always converted. E.g. `Conv2D`. * `BlackList`: Ops that are numerically unsafe to execute in float16 and can negatively affect downstream nodes. E.g. `Softmax`. @@ -267,7 +267,7 @@ def enable_mixed_precision_graph_rewrite_v1(opt, loss_scale='dynamic'): * `ClearList`: Ops that do not have numerically significant adverse effects. E.g. `ArgMax` and `Floor`. - * `WhiteList`: Ops that are considered numerically safe for execution in + * `AllowList`: Ops that are considered numerically safe for execution in float16, and thus are always converted. E.g. `Conv2D`. * `BlackList`: Ops that are numerically unsafe to execute in float16 and can negatively affect downstream nodes. E.g. `Softmax`. diff --git a/tensorflow/python/util/all_util.py b/tensorflow/python/util/all_util.py index f2e4499b64b..fa9b4f107d8 100644 --- a/tensorflow/python/util/all_util.py +++ b/tensorflow/python/util/all_util.py @@ -93,7 +93,7 @@ def remove_undocumented(module_name, allowed_exception_list=None, doc_string_modules: a list of modules from which to take the docstrings. If None, then a list containing only the module named `module_name` is used. - Furthermore, if a symbol previously added with `add_to_global_whitelist`, + Furthermore, if a symbol previously added with `add_to_global_allowlist`, then it will always be allowed. This is useful for internal tests. Returns: diff --git a/tensorflow/tools/ci_build/ci_sanity.sh b/tensorflow/tools/ci_build/ci_sanity.sh index 6db88755ac8..eb9a5a2a96e 100755 --- a/tensorflow/tools/ci_build/ci_sanity.sh +++ b/tensorflow/tools/ci_build/ci_sanity.sh @@ -96,8 +96,8 @@ do_pylint() { # --incremental Performs check on only the python files changed in the # last non-merge git commit. 
- # Use this list to whitelist pylint errors - ERROR_WHITELIST="^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline "\ + # Use this list to allowlist pylint errors + ERROR_ALLOWLIST="^tensorflow/python/framework/function_test\.py.*\[E1123.*noinline "\ "^tensorflow/python/platform/default/_gfile\.py.*\[E0301.*non-iterator "\ "^tensorflow/python/platform/default/_googletest\.py.*\[E0102.*function\salready\sdefined "\ "^tensorflow/python/feature_column/feature_column_test\.py.*\[E0110.*abstract-class-instantiated "\ @@ -115,7 +115,7 @@ do_pylint() { "^tensorflow/python/autograph/.*_py3_test\.py.*\[E0001.*syntax-error "\ "^tensorflow/python/keras/preprocessing/image\.py.*\[E0240.*Inconsistent method resolution " - echo "ERROR_WHITELIST=\"${ERROR_WHITELIST}\"" + echo "ERROR_ALLOWLIST=\"${ERROR_ALLOWLIST}\"" if [[ $# != "0" ]] && [[ $# != "1" ]]; then echo "Invalid syntax when invoking do_pylint" @@ -195,16 +195,16 @@ do_pylint() { N_ERRORS=0 while read -r LINE; do - IS_WHITELISTED=0 - for WL_REGEX in ${ERROR_WHITELIST}; do + IS_ALLOWLISTED=0 + for WL_REGEX in ${ERROR_ALLOWLIST}; do if echo ${LINE} | grep -q "${WL_REGEX}"; then - echo "Found a whitelisted error:" + echo "Found a allowlisted error:" echo " ${LINE}" - IS_WHITELISTED=1 + IS_ALLOWLISTED=1 fi done - if [[ ${IS_WHITELISTED} == "0" ]]; then + if [[ ${IS_ALLOWLISTED} == "0" ]]; then echo "${LINE}" >> ${NONWL_ERRORS_FILE} echo "" >> ${NONWL_ERRORS_FILE} ((N_ERRORS++)) @@ -213,11 +213,11 @@ do_pylint() { echo "" if [[ ${N_ERRORS} != 0 ]]; then - echo "FAIL: Found ${N_ERRORS} non-whitelisted pylint errors:" + echo "FAIL: Found ${N_ERRORS} non-allowlisted pylint errors:" cat "${NONWL_ERRORS_FILE}" return 1 else - echo "PASS: No non-whitelisted pylint errors were found." + echo "PASS: No non-allowlisted pylint errors were found." return 0 fi } @@ -370,7 +370,7 @@ do_external_licenses_check(){ -v ${MISSING_LICENSES_FILE} > temp.txt mv temp.txt ${MISSING_LICENSES_FILE} - # Whitelist + # Allowlist echo ${EXTRA_LICENSE_FILE} grep \ -e "//third_party/mkl" \ diff --git a/tensorflow/tools/test/check_futures_test.py b/tensorflow/tools/test/check_futures_test.py index 353fb694bc8..225986da9ae 100644 --- a/tensorflow/tools/test/check_futures_test.py +++ b/tensorflow/tools/test/check_futures_test.py @@ -40,7 +40,7 @@ FUTURES_PATTERN_2 = re.compile( FUTURES_PATTERN_3 = re.compile(r'^from __future__ import (\w+) as \w+\s*$') REQUIRED_FUTURES = frozenset(['absolute_import', 'division', 'print_function']) -WHITELIST = [ +ALLOWLIST = [ 'python/platform/control_imports.py', 'tools/docker/jupyter_notebook_config.py', 'tools/ci_build/update_version.py', @@ -93,12 +93,12 @@ def main(): BASE_DIR) # Verify that all files have futures - whitelist = frozenset(os.path.join(BASE_DIR, w) for w in WHITELIST) + allowlist = frozenset(os.path.join(BASE_DIR, w) for w in ALLOWLIST) old_division = frozenset(os.path.join(BASE_DIR, w) for w in OLD_DIVISION) for root, _, filenames in os.walk(BASE_DIR): for f in fnmatch.filter(filenames, '*.py'): path = os.path.join(root, f) - if path not in whitelist: + if path not in allowlist: try: check_file(path, old_division=path in old_division) except AssertionError as e: From d35df6dc46984380924c11da9e8f76df2c9df10b Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 8 Jul 2020 10:05:32 -0700 Subject: [PATCH 53/88] Rename xla_lhlo dialect into lmhlo Following on the plan of isolating the compiler/mlir/hlo directory. Another xla_lhlo dialect will be created under compiler/mlir/xla/ later. 
PiperOrigin-RevId: 320210326 Change-Id: I25147bb7687c5efeb0e61f8f9ffef27b022af5b1 --- .../mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h | 10 +- .../mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td | 10 +- .../mhlo/transforms/map_hlo_to_lhlo_op.h | 2 +- .../mhlo/transforms/map_xla_to_scalar_op.h | 186 ++++++------ .../mlir-hlo/Dialect/mhlo/transforms/passes.h | 4 +- .../Dialect/mhlo/transforms/rewriters.h | 4 +- .../Dialect/mhlo/IR/dialect_registration.cc | 2 +- .../mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc | 6 +- .../mhlo/transforms/hlo_legalize_to_lhlo.cc | 46 +-- .../mhlo/transforms/lhlo_copy_removal.cc | 6 +- .../mhlo/transforms/lhlo_fuse_linalg.cc | 4 +- .../transforms/lhlo_legalize_to_affine.cc | 22 +- .../mhlo/transforms/lhlo_legalize_to_gpu.cc | 6 +- .../mhlo/transforms/lhlo_legalize_to_llvm.cc | 4 +- .../transforms/lhlo_legalize_to_llvm_pass.cc | 6 +- .../lhlo_legalize_to_parallel_loops.cc | 62 ++-- .../mhlo/transforms/xla_legalize_to_linalg.cc | 120 ++++---- .../mlir/hlo/tests/hlo-legalize-to-lhlo.mlir | 78 ++--- .../mlir/hlo/tests/lhlo-copy-removal.mlir | 68 ++--- .../lhlo-legalize-select-and-scatter.mlir | 16 +- .../hlo/tests/lhlo-legalize-to-affine.mlir | 32 +-- .../mlir/hlo/tests/lhlo-legalize-to-gpu.mlir | 8 +- .../hlo/tests/lhlo-legalize-to-linalg.mlir | 108 +++---- .../mlir/hlo/tests/lhlo-legalize-to-llvm.mlir | 4 +- .../lhlo-legalize-to-parallel-loops.mlir | 32 +-- .../compiler/mlir/hlo/tests/lhlo_ops.mlir | 270 +++++++++--------- .../mlir/tools/kernel_gen/cubin_creator.cc | 2 +- .../xla/tests/hlo_to_lhlo_with_xla/ops.mlir | 116 ++++---- .../hlo_to_lhlo_with_xla/passthrough.mlir | 6 +- .../xla/transforms/mhlo_to_lhlo_with_xla.cc | 54 ++-- .../xla/transforms/mhlo_to_lhlo_with_xla.h | 2 +- .../xla/service/mlir_gpu/kernel_lowering.cc | 12 +- .../service/mlir_gpu/lhlo_dialect_emitter.cc | 2 +- .../xla/service/mlir_gpu/tests/abs.hlo | 2 +- .../xla/service/mlir_gpu/tests/add.hlo | 2 +- .../service/mlir_gpu/tests/add_multiply.hlo | 4 +- .../xla/service/mlir_gpu/tests/add_reduce.hlo | 4 +- .../xla/service/mlir_gpu/tests/broadcast.hlo | 2 +- .../xla/service/mlir_gpu/tests/broken_add.hlo | 2 +- .../xla/service/mlir_gpu/tests/ceil.hlo | 2 +- .../xla/service/mlir_gpu/tests/compare.hlo | 2 +- .../xla/service/mlir_gpu/tests/complex.hlo | 2 +- .../service/mlir_gpu/tests/concatenate.hlo | 2 +- .../xla/service/mlir_gpu/tests/const.hlo | 4 +- .../xla/service/mlir_gpu/tests/copy.hlo | 2 +- .../service/mlir_gpu/tests/copy_transpose.hlo | 2 +- .../xla/service/mlir_gpu/tests/cos.hlo | 2 +- .../xla/service/mlir_gpu/tests/exp.hlo | 2 +- .../service/mlir_gpu/tests/fused_reduce.hlo | 4 +- .../xla/service/mlir_gpu/tests/gather.hlo | 2 +- .../xla/service/mlir_gpu/tests/imag.hlo | 2 +- .../xla/service/mlir_gpu/tests/iota.hlo | 2 +- .../xla/service/mlir_gpu/tests/log.hlo | 2 +- .../xla/service/mlir_gpu/tests/neg.hlo | 2 +- .../xla/service/mlir_gpu/tests/real.hlo | 2 +- .../service/mlir_gpu/tests/reduce_window.hlo | 4 +- .../xla/service/mlir_gpu/tests/rem.hlo | 2 +- .../xla/service/mlir_gpu/tests/rsqrt.hlo | 2 +- .../xla/service/mlir_gpu/tests/select.hlo | 2 +- .../mlir_gpu/tests/select_and_scatter.hlo | 6 +- .../xla/service/mlir_gpu/tests/sign.hlo | 2 +- .../xla/service/mlir_gpu/tests/sqrt.hlo | 2 +- .../xla/service/mlir_gpu/tests/tanh.hlo | 2 +- 63 files changed, 700 insertions(+), 684 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h index ad1aa78b7f8..fd31bec44c0 100644 --- 
a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h @@ -35,18 +35,18 @@ class OpBuilder; #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.h.inc" -namespace xla_lhlo { +namespace lmhlo { -class XlaLhloDialect : public Dialect { +class LmhloDialect : public Dialect { public: - explicit XlaLhloDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_lhlo"; } + explicit LmhloDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "lmhlo"; } }; #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h.inc" -} // namespace xla_lhlo +} // namespace lmhlo } // end namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_LHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td index 46561bb8a03..407c260e9c8 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.td @@ -38,8 +38,8 @@ include "mlir/Interfaces/ViewLikeInterface.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def LHLO_Dialect : Dialect { - let name = "xla_lhlo"; - let cppNamespace = "xla_lhlo"; + let name = "lmhlo"; + let cppNamespace = "lmhlo"; } //===----------------------------------------------------------------------===// @@ -253,7 +253,7 @@ def LHLO_ReduceWindowOp: LHLO_Op<"reduce_window", [ // TODO(timshen): Add a custom parser to hide operand_segment_sizes. For example, // A tuple-like pattern match syntax could work: -// xla_lhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { +// lmhlo.case %index, (%input0, %input1, %input2), (%output0, %output1) { // ... // }, { // ... @@ -337,7 +337,7 @@ def HLO_StaticMemRefCastOp: Op -> memref<5xf32, offset: 2, strides: [1]> // The result of the op is a rank-1 memref with `[5]` shape, stride 1 and @@ -379,7 +379,7 @@ def HLO_DynamicMemRefCastOp: Op -> memref // The result of the op is a type-erased memref with `[%size_X, %size_Y]` // shape and `[%step_X, %step_Y]` strides. The offset will be inherited diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h index 9349ab653e4..a0246f93180 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h @@ -34,7 +34,7 @@ using HloToLhloOp = typename HloToLhloOpImpl::Type; #define MAP_HLO_TO_LHLO(OpName) \ template <> \ struct HloToLhloOpImpl { \ - using Type = xla_lhlo::OpName; \ + using Type = lmhlo::OpName; \ } MAP_HLO_TO_LHLO(AbsOp); diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h index ddd46d327af..1ed08754277 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h @@ -24,7 +24,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_hlo_to_lhlo_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace impl { // A struct to map LhloBinaryOpTy type to the corresponding floating-point and @@ -33,32 +33,32 @@ template struct LhloToScalarOp; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::AddFOp; using IOp = ::mlir::AddIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::CmpFOp; using IOp = ::mlir::CmpIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::DivFOp; using IOp = ::mlir::SignedDivIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::MulFOp; using IOp = ::mlir::MulIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::RemFOp; using IOp = ::mlir::SignedRemIOp; }; template <> -struct LhloToScalarOp { +struct LhloToScalarOp { using FOp = ::mlir::SubFOp; using IOp = ::mlir::SubIOp; }; @@ -116,16 +116,17 @@ inline Value MapLhloOpToStdScalarOp(Location loc, ArrayRef result_types, } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { - // xla_lhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) + // lmhlo.abs(x, result) -> result = select((x > 0), x, sub(0, x)) Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); @@ -133,16 +134,17 @@ inline Value MapLhloOpToStdScalarOp( b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); auto lhs_gt_zero = b->create>(loc, CmpIPredicate::sge, lhs, zero_intval); - auto neg_val = b->create>(loc, zero_intval, lhs); + auto neg_val = b->create>(loc, zero_intval, lhs); return b->create<::mlir::SelectOp>(loc, lhs_gt_zero, lhs, neg_val); } return nullptr; } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -205,30 +207,33 @@ inline Value MapXlaCompareOpToStdScalarOp(Location loc, } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return args.front(); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return 
MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, @@ -236,21 +241,23 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}(loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { Type sourceType = args.front().getType(); @@ -288,9 +295,10 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { // Dot Op converter from lhlo to affine only accepts float and integer types. const auto& lhs = args[0]; const auto& rhs = args[1]; @@ -312,17 +320,19 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -361,66 +371,69 @@ struct XlaCompareSelectOpToStdScalarOp -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return XlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "GT", - result_types, args, - b); + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "GT", result_types, + args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return XlaCompareSelectOpToStdScalarOp< - IntegerType, ScalarIOp, CmpIPredicate, FloatType, - ScalarFOp, CmpFPredicate>::map(loc, "LT", - result_types, args, - b); + IntegerType, ScalarIOp, CmpIPredicate, FloatType, + ScalarFOp, CmpFPredicate>::map(loc, "LT", result_types, + args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value 
MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } if (element_type.isa()) { - // xla_lhlo.neg(x, result) -> result = sub(0, x) + // lmhlo.neg(x, result) -> result = sub(0, x) Value lhs = args[0]; auto integer_type = element_type.dyn_cast(); auto zero_intval = b->create<::mlir::ConstantIntOp>(loc, 0, integer_type.getWidth()); - return b->create>(loc, zero_intval, lhs); + return b->create>(loc, zero_intval, lhs); } return nullptr; } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( +inline Value MapLhloOpToStdScalarOp( Location loc, ArrayRef result_types, ArrayRef args, OpBuilder* b) { return MapLhloOpToStdScalarOpImpl<::mlir::SelectOp>{}(loc, result_types, args, @@ -428,9 +441,10 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { Type element_type = args.front().getType(); if (element_type.isa()) { FloatType float_type = element_type.cast(); @@ -442,17 +456,19 @@ inline Value MapLhloOpToStdScalarOp( } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } template <> -inline Value MapLhloOpToStdScalarOp( - Location loc, ArrayRef result_types, ArrayRef args, - OpBuilder* b) { +inline Value MapLhloOpToStdScalarOp(Location loc, + ArrayRef result_types, + ArrayRef args, + OpBuilder* b) { return MapLhloOpToStdScalarOpImpl{}( loc, result_types, args, b); } @@ -460,10 +476,10 @@ inline Value MapLhloOpToStdScalarOp( } // namespace impl struct XlaOpToStdScalarOp { - // Implementation for LHLO ops except xla_lhlo::CompareOp. + // Implementation for LHLO ops except lmhlo::CompareOp. template ::value && + !std::is_same::value && std::is_same, std::false_type>::value>> static Value map(XlaOpTy op, ArrayRef result_types, @@ -475,7 +491,7 @@ struct XlaOpToStdScalarOp { // Implementation for HLO ops except mhlo::CompareOp. template , typename = std::enable_if_t< - !std::is_same::value && + !std::is_same::value && !std::is_same::value>> static Value map(XlaOpTy op, ArrayRef result_types, ArrayRef args, OpBuilder* b, int i = 0) { @@ -483,13 +499,13 @@ struct XlaOpToStdScalarOp { args, b); } - // Implementation for xla_lhlo::CompareOp. + // Implementation for lmhlo::CompareOp. 
template ::value>> - static Value map(xla_lhlo::CompareOp op, ArrayRef result_types, + LhloOpTy, lmhlo::CompareOp>::value>> + static Value map(lmhlo::CompareOp op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); - return impl::MapXlaCompareOpToStdScalarOp( + return impl::MapXlaCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } @@ -500,12 +516,12 @@ struct XlaOpToStdScalarOp { static Value map(mhlo::CompareOp op, ArrayRef result_types, ArrayRef args, OpBuilder* b) { auto comparison_direction = op.comparison_direction(); - return impl::MapXlaCompareOpToStdScalarOp( + return impl::MapXlaCompareOpToStdScalarOp( op.getLoc(), comparison_direction, result_types, args, b); } }; -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_TRANSFORMS_MAP_XLA_TO_SCALAR_OP_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h index a32de2ba6bb..db34a08b86d 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h @@ -60,7 +60,7 @@ std::unique_ptr> createXlaHloFusionPass(); } // namespace mhlo -namespace xla_lhlo { +namespace lmhlo { // Lowers from LHLO dialect to Affine dialect. std::unique_ptr> createLegalizeToAffinePass(); @@ -92,7 +92,7 @@ std::unique_ptr createLhloCopyRemovalPass(); // Lowers from LHLO dialect to parallel loops. std::unique_ptr> createLegalizeLhloToParallelLoopsPass(); -} // namespace xla_lhlo +} // namespace lmhlo namespace xla { diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index 25e55fde904..fbff046ac40 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -75,14 +75,14 @@ void PopulateUnfuseBatchNormPatterns(MLIRContext *context, } // namespace mhlo -namespace xla_lhlo { +namespace lmhlo { /// Collect a set of patterns to convert from the LHLO dialect to LLVM. void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, LLVMTypeConverter *converter, OwningRewritePatternList *patterns); -} // namespace xla_lhlo +} // namespace lmhlo namespace xla_chlo { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc index 6b8350ce13d..6d27d6015c2 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -21,4 +21,4 @@ limitations under the License. static mlir::DialectRegistration mhlo_ops; static mlir::DialectRegistration xla_chlo_ops; -static mlir::DialectRegistration xla_lhlo_ops; +static mlir::DialectRegistration lmhlo_ops; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc index 3a0d7ebfc64..49117f4bb29 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/lhlo_ops.cc @@ -46,9 +46,9 @@ limitations under the License. 
namespace mlir { #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_structs.cc.inc" -namespace xla_lhlo { +namespace lmhlo { -XlaLhloDialect::XlaLhloDialect(MLIRContext *context) +LmhloDialect::LmhloDialect(MLIRContext *context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -138,5 +138,5 @@ void FusionOp::build(OpBuilder &builder, OperationState &result, FusionOp::ensureTerminator(*bodyRegion, builder, result.location); } -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index 1162b0ecb6b..adf54828ba7 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -44,7 +44,7 @@ template using BaseOpConversion = BufferAssignmentOpConversionPattern; using StdReturnOpConverter = detail::BufferAssignmentReturnOpConverter; + lmhlo::CopyOp, true>; Value InsertDynamicAllocAndDealloc(Location loc, Value result, Value shape_operand, @@ -149,7 +149,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter Value transformed_operand = InsertDynamicMemrefCastOp(op, operands.front(), &rewriter); - rewriter.create( + rewriter.create( loc, transformed_operand, resultBuffer, op.broadcast_dimensions()); rewriter.replaceOp(op, {resultBuffer}); @@ -161,7 +161,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter // Inserts dynamic memref to change the layout of the memref to put 0-stride // and size of the target dimension if size-1 dimension expansion is // necessary. - xla_lhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( + lmhlo::DynamicMemRefCastOp InsertDynamicMemrefCastOp( mhlo::DynamicBroadcastInDimOp op, Value operand, OpBuilder* b) const { auto loc = op.getLoc(); auto operand_type = operand.getType().cast(); @@ -214,7 +214,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter makeStridedLinearLayoutMap(dynamic_layout, /*offset=*/0, b->getContext())); - auto transformed_operand = b->create( + auto transformed_operand = b->create( loc, type_erased_memref_type, operand, sizes, strides); return transformed_operand; } @@ -239,7 +239,7 @@ struct HloToLhloDynamicReshapeConverter return failure(); } mhlo::DynamicReshapeOp::Adaptor adaptor(operands); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, adaptor.operand(), adaptor.output_shape()); return success(); } @@ -266,8 +266,8 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { buffer_args.push_back( InsertAlloc(loc, result, this->bufferAssignment, &rewriter)); } - auto new_op = rewriter.create( - loc, llvm::None, buffer_args, op.getAttrs()); + auto new_op = rewriter.create(loc, llvm::None, buffer_args, + op.getAttrs()); // Copy over the operations inside the region. rewriter.inlineRegionBefore(op.body(), new_op.body(), new_op.body().end()); @@ -292,7 +292,7 @@ struct HloToLhloReduceOpConverter : public BaseOpConversion { } // Insert terminator at the end. 
rewriter.setInsertionPointToEnd(&entry_block); - rewriter.create(loc); + rewriter.create(loc); rewriter.replaceOp(op, ArrayRef(buffer_args).slice(operands.size())); @@ -321,8 +321,8 @@ class HloToLhloTensorStoreOpConverter LogicalResult matchAndRewrite( mlir::TensorStoreOp op, ArrayRef operands, ConversionPatternRewriter& rewriter) const final { - rewriter.replaceOpWithNewOp( - op, llvm::None, operands.front(), operands.back()); + rewriter.replaceOpWithNewOp(op, llvm::None, operands.front(), + operands.back()); return success(); } }; @@ -336,7 +336,7 @@ class HloToLhloTensorStoreOpConverter // %arg1: memref<2x2xf32>, // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { -// "xla_lhlo.fusion"() ({ +// "lmhlo.fusion"() ({ // %0 = tensor_load %arg1 : memref<2x2xf32> // %1 = tensor_load %arg2 : memref<2x2xf32> // %2 = "mhlo.add"(%0, %1) : @@ -345,7 +345,7 @@ class HloToLhloTensorStoreOpConverter // %4 = "mhlo.multiply"(%2, %3) : // (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> // tensor_store %4, %arg3 : memref<2x2xf32> -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) : () -> () // return // } @@ -355,13 +355,13 @@ class HloToLhloTensorStoreOpConverter // %arg1: memref<2x2xf32>, // %arg2: memref<2x2xf32>, // %arg3: memref<2x2xf32>) { -// "xla_lhlo.fusion"() ( { +// "lmhlo.fusion"() ( { // %0 = alloc() : memref<2x2xf32> -// "xla_lhlo.add"(%arg1, %arg2, %0) : +// "lmhlo.add"(%arg1, %arg2, %0) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// "xla_lhlo.multiply"(%0, %arg0, %arg3) : +// "lmhlo.multiply"(%0, %arg0, %arg3) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) : () -> () // return // } @@ -382,13 +382,13 @@ class HloToLhloTensorStoreOpConverter // %arg2: memref<4xf32>) { // %0 = alloc() : memref<4xf32> -// "xla_lhlo.maximum"(%arg0, %arg1, %0) : +// "lmhlo.maximum"(%arg0, %arg1, %0) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () // %1 = alloc() : memref<4xf32> -// "xla_lhlo.add"(%arg0, %0, %1) : +// "lmhlo.add"(%arg0, %0, %1) : // (memref<4xf32>, memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.copy"(%1, %arg2) : (memref<4xf32>, memref<4xf32>) -> () +// "lmhlo.terminator"() : () -> () // } struct HloLegalizeToLhlo @@ -406,7 +406,7 @@ struct HloLegalizeToLhlo OwningRewritePatternList patterns; auto& context = getContext(); ConversionTarget target(context); - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); target.addIllegalOp(); @@ -441,12 +441,12 @@ struct HloLegalizeToLhlo &converter, &patterns); if (results_escape_function) { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, /*allowMemrefFunctionResults=*/true>(&context, &bufferAssignment, &converter, &patterns); } else { populateWithBufferAssignmentOpConversionPatterns< - mlir::ReturnOp, mlir::ReturnOp, xla_lhlo::CopyOp, + mlir::ReturnOp, mlir::ReturnOp, lmhlo::CopyOp, /*allowMemrefFunctionResults=*/false>(&context, &bufferAssignment, &converter, &patterns); } diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc index 145cd75b61c..d2607887482 100644 --- 
a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_copy_removal.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Removes LHLO copy operations that copy from allocated buffers to block @@ -34,7 +34,7 @@ struct LhloCopyRemoval : mlir::PassWrapper> { void runOnOperation() override { llvm::SmallVector eraseList; auto operation = getOperation(); - operation->walk([&](mlir::xla_lhlo::CopyOp copyOp) { + operation->walk([&](mlir::lmhlo::CopyOp copyOp) { // If this region contains more than one block, then ignore this copy // operation. if (copyOp.getParentRegion()->getBlocks().size() > 1) { @@ -101,5 +101,5 @@ std::unique_ptr createLhloCopyRemovalPass() { static PassRegistration copy_removal_pass( "lhlo-copy-removal", "Removes redundant LHLO copy operations"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc index 5efb0fa78e5..d832b96bf7b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_fuse_linalg.cc @@ -27,7 +27,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/passes.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { using linalg::LinalgOp; @@ -147,5 +147,5 @@ static PassRegistration legalize_pass( "lhlo-fuse-linalg", "Greedily fuse linalg ops obtained after LHLO lowering."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc index e87125df86d..1b13eb46b2a 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_affine.cc @@ -28,7 +28,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Builds an affine loop nest iterating from zeros to "upper_bounds" with unit @@ -69,7 +69,7 @@ struct DotOpConverter : public OpRewritePattern { auto r = builder.create(loc, rhs, rhs_indices); auto result = rewriter.create(loc, op.output(), result_indices); - Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::XlaOpToStdScalarOp::map( op, element_type, {l, r, result}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; @@ -108,7 +108,7 @@ struct BinaryOpConverter : public OpRewritePattern { ValueRange induction_vars) { auto l = builder.create(loc, lhs, induction_vars); auto r = builder.create(loc, rhs, induction_vars); - Value op_result = xla_lhlo::XlaOpToStdScalarOp::map( + Value op_result = lmhlo::XlaOpToStdScalarOp::map( op, element_type, {l, r}, &builder); map_status = success(op_result != nullptr); if (failed(map_status)) return; @@ -127,13 +127,13 @@ void populateLHLOToAffineConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off patterns->insert< - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, - BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, + BinaryOpConverter, DotOpConverter>(context); // clang-format on } @@ -157,5 +157,5 @@ std::unique_ptr> createLegalizeToAffinePass() { static PassRegistration legalize_pass( "lhlo-legalize-to-affine", "Legalize from LHLO dialect to affine dialect"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc index 7489a092c27..1fcb881dd7b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_gpu.cc @@ -38,7 +38,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/map_xla_to_scalar_op.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // A simple translation of LHLO reduce operations to a corresponding gpu @@ -173,7 +173,7 @@ struct LhloLegalizeToGpu : public PassWrapper { OwningRewritePatternList patterns; ConversionTarget target(getContext()); target.addLegalDialect(); + gpu::GPUDialect, scf::SCFDialect, LmhloDialect>(); target.addIllegalOp(); auto func = getFunction(); patterns.insert(func.getContext()); @@ -192,5 +192,5 @@ std::unique_ptr> createLegalizeToGpuPass() { static PassRegistration legalize_pass( "lhlo-legalize-to-gpu", "Legalize from LHLO dialect to GPU dialect"); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc index 6ae3c334493..67768d56de2 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm.cc @@ -21,7 +21,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { struct StaticMemRefCastOpConverter @@ -132,5 +132,5 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, *converter, options); } -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index ade121423cf..033bbaf210e 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -23,7 +23,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { class TestLhloToLLVMPass @@ -42,7 +42,7 @@ class TestLhloToLLVMPass ConversionTarget target(getContext()); target.addLegalDialect(); target.addLegalOp(); - target.addIllegalDialect(); + target.addIllegalDialect(); if (failed(applyFullConversion(m, target, patterns))) { signalPassFailure(); @@ -55,5 +55,5 @@ class TestLhloToLLVMPass static PassRegistration legalize_lhlo_pass( "test-lhlo-legalize-to-llvm", "Legalize from LHLO dialect to LLVM."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc index 6f4da98db65..17d629ba699 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_parallel_loops.cc @@ -26,7 +26,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" namespace mlir { -namespace xla_lhlo { +namespace lmhlo { namespace { // Clones and adapts the code in `lhlo_block` that works on buffers and has a @@ -154,14 +154,14 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, return b->create(loc, lower, upper, step); } -// Converts `xla_lhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. +// Converts `lmhlo.ReduceOp` into two scf::ParallelOp and a scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops if there are // any. The inner `ParalleOp` refers to the reduction loops and `ReduceOp` // contains the reduction operator. // // Example: // -// "xla_lhlo.reduce"(%buffer, %init_buf, %result) ( { +// "lmhlo.reduce"(%buffer, %init_buf, %result) ( { // ^bb0(%lhs: memref, %rhs: memref, %res: memref): // // } ) {dimensions = dense<[1]> : tensor<1xi64>} @@ -187,12 +187,12 @@ scf::ParallelOp MakeLoopOverShape(Location loc, Value shaped_value, // } : f32 // scf.yield // } -class ReduceOpConverter : public OpConversionPattern { +class ReduceOpConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, + lmhlo::ReduceOp xla_reduce_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { // TODO(b/137624192) Implement variadic reduce. 
if (xla_reduce_op.out().size() != 1) return failure(); @@ -226,7 +226,7 @@ class ReduceOpConverter : public OpConversionPattern { // scf.yield // } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - xla_lhlo::ReduceOp xla_reduce_op, + lmhlo::ReduceOp xla_reduce_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_op.getLoc(); DenseSet reducing_dims; @@ -314,7 +314,7 @@ class ReduceOpConverter : public OpConversionPattern { // accumulator = reduction_operator(output[O], value) // output[O] = accumulator // -// Converts `xla_lhlo.ReduceWindowOp` into two scf::ParallelOp and a +// Converts `lmhlo.ReduceWindowOp` into two scf::ParallelOp and a // scf::ReduceOp. // The outper `ParallelOp` refers to the parallel loops that traverese output // buffer. The inner `ParalleOp` refers to the reduction loops that traverse @@ -325,11 +325,11 @@ class ReduceOpConverter : public OpConversionPattern { // func @reduce_window(%arg: memref<112x112xf32>, // %init: memref, // %result: memref<56x56xf32>) { -// "xla_lhlo.reduce_window"(%arg, %init, %result) ( { +// "lmhlo.reduce_window"(%arg, %init, %result) ( { // ^bb0(%lhs: memref, %rhs: memref, %res: memref): -// "xla_lhlo.maximum"(%lhs, %rhs, %res) +// "lmhlo.maximum"(%lhs, %rhs, %res) // : (memref, memref, memref) -> () -// "xla_lhlo.terminator"() : () -> () +// "lmhlo.terminator"() : () -> () // }) { // padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, // window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -359,12 +359,12 @@ class ReduceOpConverter : public OpConversionPattern { // return // } class ReduceWindowOpConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, + lmhlo::ReduceWindowOp xla_reduce_window_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { scf::ParallelOp output_loop, window_loop; std::tie(output_loop, window_loop) = @@ -383,7 +383,7 @@ class ReduceWindowOpConverter private: std::pair CreateParallelLoopsToTraverseOutputAndWindow( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, + lmhlo::ReduceWindowOp xla_reduce_window_op, ConversionPatternRewriter* rewriter) const { auto loc = xla_reduce_window_op.getLoc(); Value init_value = @@ -415,9 +415,8 @@ class ReduceWindowOpConverter } scf::ReduceOp CreateReduceOpInNestedParallelLoops( - xla_lhlo::ReduceWindowOp xla_reduce_window_op, - scf::ParallelOp output_loop, scf::ParallelOp window_loop, - ConversionPatternRewriter* rewriter) const { + lmhlo::ReduceWindowOp xla_reduce_window_op, scf::ParallelOp output_loop, + scf::ParallelOp window_loop, ConversionPatternRewriter* rewriter) const { rewriter->setInsertionPointToStart(window_loop.getBody()); auto loc = xla_reduce_window_op.getLoc(); @@ -481,12 +480,12 @@ class ReduceWindowOpConverter // initialized_flag = true // output(selected_index) = scatter(output(selected_index), source(S)) class SelectAndScatterOpConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, + lmhlo::SelectAndScatterOp s_and_s_op, ArrayRef /*args*/, ConversionPatternRewriter& rewriter) const final { auto loc = s_and_s_op.getLoc(); InitializeOutput(s_and_s_op, &rewriter); @@ -515,7 +514,7 
@@ class SelectAndScatterOpConverter } private: - void InitializeOutput(xla_lhlo::SelectAndScatterOp s_and_s_op, + void InitializeOutput(lmhlo::SelectAndScatterOp s_and_s_op, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value init_value = b->create(loc, s_and_s_op.init_value()); @@ -533,7 +532,7 @@ class SelectAndScatterOpConverter SmallVector window_ivs; scf::ForOp inner_loop; }; - WindowLoops InsertWindowLoops(xla_lhlo::SelectAndScatterOp s_and_s_op, + WindowLoops InsertWindowLoops(lmhlo::SelectAndScatterOp s_and_s_op, scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -598,7 +597,7 @@ class SelectAndScatterOpConverter SmallVector ivs_val_flag_; }; - SmallVector SelectIvs(xla_lhlo::SelectAndScatterOp s_and_s_op, + SmallVector SelectIvs(lmhlo::SelectAndScatterOp s_and_s_op, scf::ParallelOp loop_over_src, OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); @@ -636,9 +635,10 @@ class SelectAndScatterOpConverter return window_loops.selected_ivs; } - SmallVector SelectOrInitialize( - xla_lhlo::SelectAndScatterOp s_and_s_op, ArrayRef operand_ivs, - IterArgs* ivs_val_flag, OpBuilder* b) const { + SmallVector SelectOrInitialize(lmhlo::SelectAndScatterOp s_and_s_op, + ArrayRef operand_ivs, + IterArgs* ivs_val_flag, + OpBuilder* b) const { auto loc = s_and_s_op.getLoc(); Value true_i1 = b->create( loc, b->getI1Type(), b->getIntegerAttr(b->getI1Type(), 1)); @@ -707,9 +707,9 @@ struct LhloLegalizeToParallelLoops ConversionTarget target(getContext()); target.addLegalDialect(); - target.addIllegalOp(); + scf::SCFDialect, LmhloDialect>(); + target.addIllegalOp(); if (failed(applyPartialConversion(func, target, patterns))) { signalPassFailure(); @@ -727,5 +727,5 @@ static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-parallel-loops", "Legalize from LHLO dialect to parallel loops."); -} // namespace xla_lhlo +} // namespace lmhlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc index 03df2bc8bf1..95aa403c874 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_legalize_to_linalg.cc @@ -131,9 +131,9 @@ class PointwiseToLinalgConverter : public OpConversionPattern { loc, opResultTypes, args, args_count, results_count, indexing_maps, GetNParallelLoopsAttrs(nloops), [&](OpBuilder& nestedBuilder, Location nestedLoc, ValueRange args) { - // TODO(ravishankarm) : For now use the method in xla_lhlo namespace. + // TODO(ravishankarm) : For now use the method in lmhlo namespace. // That method needs to be moved out of there. - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + Value opResult = lmhlo::XlaOpToStdScalarOp::map( op, bodyResultTypes, llvm::to_vector<2>(args.take_front(args_count)), &rewriter); nestedBuilder.create(loc, opResult); @@ -162,8 +162,8 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { // Create two loads from the input. auto lhs = rewriter.create(loc, lhlo_op.lhs()); auto rhs = rewriter.create(loc, lhlo_op.rhs()); - // TODO(ravishankarm) : Move this method out of xla_lhlo namespace. - Value opResult = xla_lhlo::XlaOpToStdScalarOp::map( + // TODO(ravishankarm) : Move this method out of lmhlo namespace. 
+ Value opResult = lmhlo::XlaOpToStdScalarOp::map( lhlo_op, argType.getElementType(), llvm::ArrayRef{lhs, rhs}, &rewriter); rewriter.create(loc, opResult, lhlo_op.out()); @@ -173,21 +173,21 @@ class ScalarPointwiseToStandardConverter : public OpConversionPattern { }; //===----------------------------------------------------------------------===// -// xla_lhlo.convolution conversion pattern. +// lmhlo.convolution conversion pattern. //===----------------------------------------------------------------------===// -/// Converts xla_lhlo.convolution operation to a linalg.conv op. -struct ConvToLinalgConverter : public OpConversionPattern { +/// Converts lmhlo.convolution operation to a linalg.conv op. +struct ConvToLinalgConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; // This code has been adapted from IREE's // (https://github.com/google/iree/) mhlo -> linalg conversion. LogicalResult matchAndRewrite( - xla_lhlo::ConvOp op, ArrayRef args, + lmhlo::ConvOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { // Check validity of dimension information. - if (const xla_lhlo::ConvDimensionNumbers& dimensionNumbers = + if (const lmhlo::ConvDimensionNumbers& dimensionNumbers = op.dimension_numbers()) { const int inputSpatialRank = llvm::size(dimensionNumbers.input_spatial_dimensions()); @@ -388,14 +388,14 @@ class HloBroadcastInDimConverter }; class LhloBroadcastInDimConverter - : public OpConversionPattern { + : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::BroadcastInDimOp op, ArrayRef args, + lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const final { - xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); auto result_type = operand_adaptor.output().getType().cast(); auto result_shape = result_type.getShape(); @@ -444,9 +444,9 @@ class LhloBroadcastInDimConverter // Inserts 'linalg.reshape' if there is a size-1 dim expansion. 
std::pair> InsertReshapeIfNecessary( - xla_lhlo::BroadcastInDimOp op, ArrayRef args, + lmhlo::BroadcastInDimOp op, ArrayRef args, ConversionPatternRewriter& rewriter) const { - xla_lhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); + lmhlo::BroadcastInDimOp::Adaptor operand_adaptor(args); Value operand = operand_adaptor.operand(); auto operand_type = operand_adaptor.operand().getType().cast(); auto operand_shape = operand_type.getShape(); @@ -512,7 +512,7 @@ class LhloBroadcastInDimConverter return std::make_pair(operand, broadcast_dims); } - SmallVector getIndexingMaps(xla_lhlo::BroadcastInDimOp op, + SmallVector getIndexingMaps(lmhlo::BroadcastInDimOp op, ArrayRef broadcastDims, ArrayRef resultShape, MemRefType operandType, @@ -639,12 +639,12 @@ class ReshapeOpConverter : public OpConversionPattern { } }; -class IotaConverter : public OpConversionPattern { +class IotaConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::IotaOp iotaOp, ArrayRef args, + lmhlo::IotaOp iotaOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto resultMemrefType = iotaOp.getOperand().getType().dyn_cast(); @@ -680,12 +680,12 @@ class IotaConverter : public OpConversionPattern { } }; -class ConstConverter : public OpConversionPattern { +class ConstConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::ConstOp constOp, ArrayRef args, + lmhlo::ConstOp constOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = constOp.getLoc(); auto valueAttr = constOp.value().cast(); @@ -726,12 +726,12 @@ class ReverseConverter } }; -class SliceConverter : public OpConversionPattern { +class SliceConverter : public OpConversionPattern { public: - using OpConversionPattern::OpConversionPattern; + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite( - xla_lhlo::SliceOp sliceOp, ArrayRef args, + lmhlo::SliceOp sliceOp, ArrayRef args, ConversionPatternRewriter& rewriter) const final { auto loc = sliceOp.getLoc(); auto argType = @@ -763,50 +763,50 @@ class SliceConverter : public OpConversionPattern { void populateLHLOToLinalgConversionPattern(MLIRContext* context, OwningRewritePatternList* patterns) { // clang-format off - patterns->insert, + patterns->insert, ConstConverter, ConvToLinalgConverter, IotaConverter, LhloBroadcastInDimConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, // TODO(ataei): Remove this pattern, CopyOp is folded away. 
- PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - PointwiseToLinalgConverter, - ReshapeOpConverter, - ReverseConverter, - ScalarPointwiseToStandardConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + PointwiseToLinalgConverter, + ReshapeOpConverter, + ReverseConverter, + ScalarPointwiseToStandardConverter, SliceConverter >(context); // clang-format on } // Converts LHLO ops to Linalg generic. -// Sample result for xla_lhlo::AddOp. +// Sample result for lmhlo::AddOp. // -// "xla_lhlo.add"(%arg1, %arg2, %out) : +// "lmhlo.add"(%arg1, %arg2, %out) : // (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () // // will be converted to @@ -854,14 +854,14 @@ struct HloLegalizeToLinalg } // namespace -namespace xla_lhlo { +namespace lmhlo { std::unique_ptr> createLegalizeLhloToLinalgPass() { return absl::make_unique(); } static PassRegistration legalize_lhlo_pass( "lhlo-legalize-to-linalg", "Legalize from LHLO dialect to Linalg dialect"); -} // namespace xla_lhlo +} // namespace lmhlo namespace mhlo { diff --git a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir index a5559357bdc..aa5d800b82b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/hlo-legalize-to-lhlo.mlir @@ -7,7 +7,7 @@ func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_result = "mhlo.exponential"(%tensor_operand) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} + // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) {some_attr_1 = "exp.1", some_attr_2 = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -18,10 +18,10 @@ func @return_func(%arg0: tensor<4xf32>) -> tensor<4xf32> { return %arg0 : tensor<4xf32> } // PRE: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) -// PRE-NEXT: "xla_lhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () +// PRE-NEXT: "lmhlo.copy"(%[[ARG0]], %[[RESULT]]) : ([[TYPE]], [[TYPE]]) -> () // PRE-NEXT: return // ESC: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] -// ESC-NOT: "xla_lhlo.copy" +// ESC-NOT: "lmhlo.copy" // ESC-NEXT: return %[[ARG0]] // ----- @@ -38,20 +38,20 @@ func @func_op_long(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> // PRE: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>, %[[RESULT:.*]]: memref<4xf32>) 
// ESC: (%[[NEW_ARG0:.*]]: memref<4xf32>, %[[NEW_ARG1:.*]]: memref<4xf32>) -> memref<4xf32> // BOTH-NEXT: %[[MAX_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) +// BOTH-NEXT: "lmhlo.maximum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MAX_RESULT]]) // BOTH-NEXT: %[[ADD_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) +// BOTH-NEXT: "lmhlo.add"(%[[NEW_ARG0]], %[[MAX_RESULT]], %[[ADD_RESULT]]) // BOTH-NEXT: dealloc %[[MAX_RESULT]] : memref<4xf32> // BOTH-NEXT: %[[MIN_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) +// BOTH-NEXT: "lmhlo.minimum"(%[[NEW_ARG0]], %[[NEW_ARG1]], %[[MIN_RESULT]]) // BOTH-NEXT: %[[SUB_RESULT:.*]] = alloc() : memref<4xf32> -//  BOTH-NEXT: "xla_lhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) +//  BOTH-NEXT: "lmhlo.subtract"(%[[NEW_ARG1]], %[[MIN_RESULT]], %[[SUB_RESULT]]) // BOTH-NEXT: dealloc %[[MIN_RESULT]] : memref<4xf32> // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<4xf32> -// BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) +// BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %[[SUB_RESULT]], %[[MUL_RESULT]]) // BOTH-NEXT: dealloc %[[SUB_RESULT]] : memref<4xf32> // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<4xf32> -// PRE-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () +// PRE-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) : (memref<4xf32>, memref<4xf32>) -> () // PRE-NEXT: dealloc %[[MUL_RESULT]] : memref<4xf32> // PRE-NEXT: return // ESC-NEXT: return %[[MUL_RESULT]] : memref<4xf32> @@ -67,14 +67,14 @@ func @fusion(%multiplier: memref<2x2xf32>, %summand_1: memref<2x2xf32>, %tensor_summand_2 = tensor_load %summand_2 : memref<2x2xf32> %sum = "mhlo.add"(%tensor_summand_1, %tensor_summand_2) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "xla_lhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) + // BOTH-NEXT: "lmhlo.add"(%{{.*}}, %{{.*}}, %[[ADD_RESULT]]) // BOTH-NEXT: %[[MUL_RESULT:.*]] = alloc() : memref<2x2xf32> %tensor_multiplier = tensor_load %multiplier : memref<2x2xf32> %tensor_result = "mhlo.multiply"(%sum, %tensor_multiplier) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH-NEXT: "xla_lhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) + // BOTH-NEXT: "lmhlo.multiply"(%[[ADD_RESULT]], %{{.*}}, %[[MUL_RESULT]]) // BOTH-NEXT: dealloc %[[ADD_RESULT]] : memref<2x2xf32> - // BOTH-NEXT: "xla_lhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) + // BOTH-NEXT: "lmhlo.copy"(%[[MUL_RESULT]], %[[RESULT]]) tensor_store %tensor_result, %result : memref<2x2xf32> // BOTH-NEXT: dealloc %[[MUL_RESULT]] : memref<2x2xf32> // BOTH-NEXT: return @@ -88,7 +88,7 @@ func @copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.copy"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -100,7 +100,7 @@ func @exp(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.exponential"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.exponential"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.exponential"(%{{.*}}, %{{.*}}) tensor_store 
%tensor_result, %result : memref<2x2xf32> return } @@ -112,7 +112,7 @@ func @log(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.log"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.log"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.log"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -127,7 +127,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.select"(%tensor_pred, %tensor_lhs, %tensor_rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) + // BOTH: "lmhlo.select"(%{{.*}}, %{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -141,7 +141,7 @@ func @compare(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2x %tensor_result = "mhlo.compare"(%tensor_lhs, %tensor_rhs) {comparison_direction = "EQ"} : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xi1> - // BOTH: "xla_lhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} + // BOTH: "lmhlo.compare"(%{{.*}}, %{{.*}}, %{{.*}}) {comparison_direction = "EQ"} tensor_store %tensor_result, %result : memref<2x2xi1> return } @@ -154,7 +154,7 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) { %tensor_result = "mhlo.broadcast_in_dim"(%tensor_operand) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<5xf32>) -> tensor<10x5xf32> - // BOTH: "xla_lhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} + // BOTH: "lmhlo.broadcast_in_dim"(%{{.*}}, %{{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} tensor_store %tensor_result, %result : memref<10x5xf32> return } @@ -205,12 +205,12 @@ func @dyn_broadcast(%operand: memref) { // BOTH: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]] // BOTH: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index - // BOTH: %[[TRANSFORMED_MEMREF:.*]] = xla_lhlo.dynamic_memref_cast + // BOTH: %[[TRANSFORMED_MEMREF:.*]] = lmhlo.dynamic_memref_cast // BOTH-SAME: %[[OPERAND]](%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]) // BOTH-SAME: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]] // BOTH-SAME: : memref -> memref - // BOTH: "xla_lhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { + // BOTH: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) { // BOTH-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64> // BOTH-SAME: } : (memref, memref) -> () @@ -229,7 +229,7 @@ func @complex(%real: memref<2x2xf32>, %tensor_imag = tensor_load %imag : memref<2x2xf32> %tensor_result = "mhlo.complex"(%tensor_real, %tensor_imag) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xcomplex> - // BOTH: "xla_lhlo.complex"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.complex"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xcomplex> return } @@ -241,7 +241,7 @@ func @real(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xcomplex> %tensor_result = "mhlo.real"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.real"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.real"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -253,7 +253,7 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : 
memref<2x2xcomplex> %tensor_result = "mhlo.imag"(%tensor_operand) : (tensor<2x2xcomplex>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.imag"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.imag"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -264,7 +264,7 @@ func @imag(%operand: memref<2x2xcomplex>, %result: memref<2x2xf32>) { func @iota(%result: memref<10xi32>) { %tensor_result = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<10xi32> - // BOTH: "xla_lhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} + // BOTH: "lmhlo.iota"(%{{.*}}) {iota_dimension = 0 : i64} tensor_store %tensor_result, %result : memref<10xi32> return } @@ -276,7 +276,7 @@ func @abs(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.abs"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.abs"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.abs"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -288,7 +288,7 @@ func @ceil(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.ceil"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.ceil"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.ceil"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -300,7 +300,7 @@ func @convert(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.convert"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.copy"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.copy"(%{{.*}}, %{{.*}}) // BOTH-NOT: tensor_store tensor_store %tensor_result, %result : memref<2x2xf32> return @@ -313,7 +313,7 @@ func @cos(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.cosine"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.cosine"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.cosine"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -325,7 +325,7 @@ func @neg(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.negate"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.negate"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.negate"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -337,7 +337,7 @@ func @rsqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.rsqrt"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.rsqrt"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.rsqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -349,7 +349,7 @@ func @sign(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.sign"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.sign"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.sign"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -361,7 +361,7 @@ func @sqrt(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.sqrt"(%tensor_operand) : 
(tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.sqrt"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.sqrt"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -373,7 +373,7 @@ func @tanh(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) { %tensor_operand = tensor_load %operand : memref<2x2xf32> %tensor_result = "mhlo.tanh"(%tensor_operand) : (tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.tanh"(%{{.*}}, %{{.*}}) + // BOTH: "lmhlo.tanh"(%{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -386,7 +386,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x %tensor_rhs = tensor_load %rhs : memref<2x2xf32> %tensor_result = "mhlo.remainder"(%tensor_lhs, %tensor_rhs) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> - // BOTH: "xla_lhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) + // BOTH: "lmhlo.remainder"(%{{.*}}, %{{.*}}, %{{.*}}) tensor_store %tensor_result, %result : memref<2x2xf32> return } @@ -412,7 +412,7 @@ func @add_dyn(%lhs: tensor, %rhs: tensor) { // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "xla_lhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () + // BOTH: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref, memref, memref) -> () return } @@ -437,7 +437,7 @@ func @tanh_dyn(%arg0: tensor) { // BOTH: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64> // BOTH: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index // BOTH: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]]) - // BOTH: "xla_lhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () + // BOTH: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref, memref) -> () return } @@ -448,10 +448,10 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { // PRE-SAME: (%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[TYPE]]) // ESC-SAME: (%[[ARG0:.*]]: [[TYPE:.*]]) -> [[TYPE]] // BOTH-NEXT: %[[ALLOC:.*]] = alloc -// BOTH: "xla_lhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// BOTH: "lmhlo.dot"(%[[ARG0]], %[[ARG0]], %[[ALLOC]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () %dot = "mhlo.dot"(%arg0, %arg0) : (tensor<1024x1024xf32>, tensor<1024x1024xf32>) -> tensor<1024x1024xf32> -// PRE: "xla_lhlo.copy"(%[[ALLOC]], %[[RESULT]]) +// PRE: "lmhlo.copy"(%[[ALLOC]], %[[RESULT]]) // ESC: return %[[ALLOC]] return %dot : tensor<1024x1024xf32> } @@ -462,7 +462,7 @@ func @dot(%arg0: tensor<1024x1024xf32>) -> tensor<1024x1024xf32> { func @conv(%input: tensor<3x5x5x3xf32>, %filter : tensor<2x2x3x4xf32>) -> tensor<3x5x5x4xf32> { %c0 = constant 0 : index // BOTH: %[[OUT:.*]] = alloc() : memref<3x5x5x4xf32> - // BOTH: "xla_lhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) + // BOTH: "lmhlo.convolution"(%{{.+}}, %{{.+}}, %[[OUT]]) // BOTH-SAME: padding = dense<[ // BOTH-SAME: [0, 1], [0, 1]]> : tensor<2x2xi64> // BOTH-SAME: rhs_dilation = dense<[1, 2]> diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir index 3d3f802dcb1..6d7992cb868 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-copy-removal.mlir @@ -3,10 +3,10 @@ // CHECK-LABEL: func @remove_simple func @remove_simple(%arg0: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) 
-> () + "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -14,9 +14,9 @@ func @remove_simple(%arg0: memref<2x2xf32>) { // CHECK-LABEL: func @remove_without_dealloc func @remove_without_dealloc(%arg0: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%0, %arg0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -24,22 +24,22 @@ func @remove_without_dealloc(%arg0: memref<2x2xf32>) { // CHECK-LABEL: func @replace_dependency func @replace_dependency(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- // CHECK-LABEL: func @keep_copies func @keep_copies(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>) { - // CHECK-NEXT: "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.terminator"() : () -> () - "xla_lhlo.terminator"() : () -> () + // CHECK-NEXT: "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%arg0, %arg1) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -50,14 +50,14 @@ func @must_not_be_removed(%arg0: memref<2x2xf32>, %arg2: memref<2x2xf32>) { // CHECK-NEXT: %[[ALLOC:.*]] = alloc() {temp = true} : memref<2x2xf32> %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %[[ALLOC]]) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.copy"(%[[ALLOC]], %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + 
"lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -67,13 +67,13 @@ func @must_be_removed_first(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // ----- @@ -83,11 +83,11 @@ func @must_be_removed_second(%arg0: memref<2x2xf32>, %arg1: memref<2x2xf32>, %arg2: memref<2x2xf32>) { %0 = alloc() {temp = true} : memref<2x2xf32> - // CHECK-NEXT: "xla_lhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - // CHECK-NEXT: "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () - "xla_lhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg0, %0) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.copy"(%0, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + // CHECK-NEXT: "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%arg1, %arg2) : (memref<2x2xf32>, memref<2x2xf32>) -> () dealloc %0 : memref<2x2xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-select-and-scatter.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-select-and-scatter.mlir index 2aa6378d2a9..b110d8d257e 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-select-and-scatter.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-select-and-scatter.mlir @@ -10,18 +10,18 @@ func @select_and_scatter(%arg: memref<112x112xf32>, %src: memref<56x56xf32>, %init: memref, %result: memref<112x112xf32>) { - "xla_lhlo.select_and_scatter"(%arg, %src, %init, %result) ( { + "lmhlo.select_and_scatter"(%arg, %src, %init, %result) ( { // select ^bb0(%lhs: memref, %rhs: memref, %pred: memref): - "xla_lhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : + "lmhlo.compare"(%lhs, %rhs, %pred) {comparison_direction = "GE"} : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }, { // scatter ^bb0(%lhs: memref, %rhs: memref, %out: memref): - "xla_lhlo.add"(%lhs, %rhs, %out) : + "lmhlo.add"(%lhs, %rhs, %out) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + 
"lmhlo.terminator"() : () -> () }) { padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -29,7 +29,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, } : (memref<112x112xf32>, memref<56x56xf32>, memref, memref<112x112xf32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } // CHECK-LABEL: func @select_and_scatter( // CHECK-SAME: [[ARG_BUF:%.*]]: memref<112x112xf32>, @@ -121,7 +121,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: store [[SEL_VAL]], [[SEL_VAL_BUF]][] : memref // Compute PRED. - // CHECK: "xla_lhlo.compare"( + // CHECK: "lmhlo.compare"( // CHECK-SAME: [[ARG_ELEM_BUF]], [[SEL_VAL_BUF]], [[PRED_BUF]]) // CHECK: [[PRED:%.*]] = load [[PRED_BUF]][] : memref @@ -182,7 +182,7 @@ func @select_and_scatter(%arg: memref<112x112xf32>, // CHECK: store [[CUR_RES]], [[CUR_RES_BUF]][] : memref // Compute scatter value. -// CHECK: "xla_lhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : +// CHECK: "lmhlo.add"([[SRC_ELEM_BUF]], [[CUR_RES_BUF]], [[RES_BUF]]) : // CHECK-SAME: (memref, memref, memref) -> () // CHECK: [[RES:%.*]] = load [[RES_BUF]][] : memref diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir index 1068d1a94a0..87818045993 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-affine.mlir @@ -14,7 +14,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, // CHECK-NEXT: %[[MIN:.*]] = select %[[MIN_PREDICATE]], %[[LHS]], %[[RHS]] : f32 // CHECK-NEXT: affine.store %[[MIN]], %{{.*}}[%[[I]], %[[J]], %[[K]], %[[L]]] : memref<4x3x2x1xf32> // CHECK: return - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<4x3x2x1xf32>, memref<4x3x2x1xf32>, memref<4x3x2x1xf32>) -> () return } @@ -24,7 +24,7 @@ func @min_op(%lhs: memref<4x3x2x1xf32>, %rhs: memref<4x3x2x1xf32>, func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: addf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -32,7 +32,7 @@ func @float_add_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: addi %{{.*}}, %{{.*}} : i32 - "xla_lhlo.add"(%lhs, %rhs, %result) {name = "add.1"} + "lmhlo.add"(%lhs, %rhs, %result) {name = "add.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -42,7 +42,7 @@ func @int_add_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: and %{{.*}}, %{{.*}} : i32 - "xla_lhlo.and"(%lhs, %rhs, %result) {name = "and.1"} + "lmhlo.and"(%lhs, %rhs, %result) {name = "and.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -52,7 +52,7 @@ func @int_and_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_div_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: divf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -60,7 +60,7 @@ func @float_div_op(%lhs: 
memref<7xf32>, %rhs: memref<7xf32>, func @int_div_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: divi_signed %{{.*}}, %{{.*}} : i32 - "xla_lhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} + "lmhlo.divide"(%lhs, %rhs, %result) {name = "div.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -71,7 +71,7 @@ func @float_max_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf "ogt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -81,7 +81,7 @@ func @int_max_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "sgt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} + "lmhlo.maximum"(%lhs, %rhs, %result) {name = "max.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -92,7 +92,7 @@ func @float_min_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: %[[CHECK:.*]] = cmpf "olt", %[[ONE:.*]], %[[TWO:.*]] : f32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : f32 - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -102,7 +102,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: %[[CHECK:.*]] = cmpi "slt", %[[ONE:.*]], %[[TWO:.*]] : i32 // CHECK: select %[[CHECK]], %[[ONE]], %[[TWO]] : i32 - "xla_lhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} + "lmhlo.minimum"(%lhs, %rhs, %result) {name = "min.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -112,7 +112,7 @@ func @int_min_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: mulf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -121,7 +121,7 @@ func @float_mul_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: muli %{{.*}}, %{{.*}} : i32 - "xla_lhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} + "lmhlo.multiply"(%lhs, %rhs, %result) {name = "mul.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } @@ -131,7 +131,7 @@ func @int_mul_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, %result: memref<7xf32>) -> () { // CHECK: subf %{{.*}}, %{{.*}} : f32 - "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xf32>, memref<7xf32>, memref<7xf32>) -> () return } @@ -139,7 +139,7 @@ func @float_sub_op(%lhs: memref<7xf32>, %rhs: memref<7xf32>, func @int_sub_op(%lhs: memref<7xi32>, %rhs: memref<7xi32>, %result: memref<7xi32>) -> () { // CHECK: subi %{{.*}}, %{{.*}} : i32 - "xla_lhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} + "lmhlo.subtract"(%lhs, %rhs, %result) {name = "sub.1"} : (memref<7xi32>, memref<7xi32>, memref<7xi32>) -> () return } 
@@ -158,7 +158,7 @@ func @float_dot_op(%lhs: memref<7x3xf32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addf %[[MULT]], %[[RESULT]] : f32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xf32> // CHECK: return - "xla_lhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) : (memref<7x3xf32>, memref<3x4xf32>, memref<7x4xf32>) -> () return } @@ -175,7 +175,7 @@ func @int_dot_op(%lhs: memref<7x3xi32>, %rhs: // CHECK-NEXT: %[[ADD:.*]] = addi %[[MULT]], %[[RESULT]] : i32 // CHECK-NEXT: affine.store %[[ADD]], %{{.*}}[%[[I]], %[[J]]] : memref<7x4xi32> // CHECK: return - "xla_lhlo.dot"(%lhs, %rhs, %result) : + "lmhlo.dot"(%lhs, %rhs, %result) : (memref<7x3xi32>, memref<3x4xi32>, memref<7x4xi32>) -> () return } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-gpu.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-gpu.mlir index e996581c593..02ad3653639 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-gpu.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-gpu.mlir @@ -3,11 +3,11 @@ func @reduce(%arg: memref<100x10xf32>, %init: memref, %result: memref<100xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<100x10xf32>, memref, memref<100xf32>) -> () return @@ -25,7 +25,7 @@ func @reduce(%arg: memref<100x10xf32>, // CHECK: scf.for %[[IDX1:.*]] = %[[LB]] to %[[UB]] step %[[STEP]] { // CHECK: %[[LHS:.*]] = linalg.slice %[[ARG2]][%[[IDX]]] : memref<100xf32>, index, memref // CHECK: %[[RHS:.*]] = linalg.slice %[[ARG0]][%[[IDX]], %[[IDX1]]] : memref<100x10xf32>, index, index, memref -// CHECK: "xla_lhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () +// CHECK: "lmhlo.add"(%[[LHS]], %[[RHS]], %[[LHS]]) : (memref, memref, memref) -> () // CHECK: } // CHECK: gpu.terminator // CHECK: } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir index 8ebfb6b5ce1..6981466dc46 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-linalg.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: func @element_wise func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -19,7 +19,7 @@ func @element_wise(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, func @element_wise_with_dynamic_shape(%lhs: memref, %rhs: memref, %result: memref) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref, memref, memref) -> () return } @@ -33,7 +33,7 @@ func @element_wise_with_dynamic_shape(%lhs: memref, // CHECK-LABEL: func @element_wise_scalar func @element_wise_scalar(%lhs: memref, %rhs: memref, %result: memref) { - "xla_lhlo.add"(%lhs, %rhs, %result) + "lmhlo.add"(%lhs, %rhs, %result) : (memref, memref, memref) -> () return } @@ -48,7 +48,7 @@ func @element_wise_scalar(%lhs: memref, %rhs: memref, // CHECK-LABEL: func @minf func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.minimum"(%lhs, %rhs, %result) + 
"lmhlo.minimum"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -63,7 +63,7 @@ func @minf(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @maxi func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.maximum"(%lhs, %rhs, %result) + "lmhlo.maximum"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -78,7 +78,7 @@ func @maxi(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @and func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.and"(%lhs, %rhs, %result) + "lmhlo.and"(%lhs, %rhs, %result) : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -91,7 +91,7 @@ func @and(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.exponential"(%input, %result) + "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -104,7 +104,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @log func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.log"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -116,7 +116,7 @@ func @log(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @copy func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { - "xla_lhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () + "lmhlo.copy"(%in, %out) : (memref<2x4x8xf32>, memref<2x4x8xf32>) -> () return } // CHECK: linalg.generic @@ -128,7 +128,7 @@ func @copy(%in: memref<2x4x8xf32>, %out: memref<2x4x8xf32>) { // CHECK-LABEL: func @float_cmp func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "EQ"} : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xi1>) -> () return } @@ -142,7 +142,7 @@ func @float_cmp(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @int_cmp func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, %result: memref<2x2xi1>) { - "xla_lhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} + "lmhlo.compare"(%lhs, %rhs, %result) {comparison_direction = "LT"} : (memref<2x2xi32>, memref<2x2xi32>, memref<2x2xi1>) -> () return } @@ -156,7 +156,7 @@ func @int_cmp(%lhs: memref<2x2xi32>, %rhs: memref<2x2xi32>, // CHECK-LABEL: func @select func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.select"(%pred, %lhs, %rhs, %result) + "lmhlo.select"(%pred, %lhs, %rhs, %result) : (memref<2x2xi1>, memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -170,7 +170,7 @@ func @select(%pred: memref<2x2xi1>, %lhs: memref<2x2xf32>, // CHECK: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @iota func @iota(%out: memref<7x10xf32>) { - "xla_lhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () + "lmhlo.iota"(%out) {iota_dimension = 1 : i64} : (memref<7x10xf32>) -> () return } // CHECK: linalg.indexed_generic @@ -186,7 +186,7 @@ func @iota(%out: memref<7x10xf32>) { // CHECK-DAG: #[[RESULT_MAP:.+]] = affine_map<(d0, d1, d2) -> (d0, d1, d2)> // CHECK-LABEL: func 
@broadcast_scalar func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { - "xla_lhlo.broadcast"(%operand, %result) { + "lmhlo.broadcast"(%operand, %result) { broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> } : (memref, memref<4x2x1xf32>) -> () return @@ -203,7 +203,7 @@ func @broadcast_scalar(%operand: memref, %result: memref<4x2x1xf32>) { // CHECK-LABEL: func @broadcast func @broadcast(%operand: memref<4x?x16xf32>, %result: memref<4x2x1x4x?x16xf32>) { - "xla_lhlo.broadcast"(%operand, %result) { + "lmhlo.broadcast"(%operand, %result) { broadcast_sizes = dense<[4, 2, 1]> : tensor<3xi64> } : (memref<4x?x16xf32>, memref<4x2x1x4x?x16xf32>) -> () return @@ -220,7 +220,7 @@ func @broadcast(%operand: memref<4x?x16xf32>, // CHECK-LABEL: func @dynamic_broadcast_in_dim func @dynamic_broadcast_in_dim(%operand: memref, %result: memref) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[4,0,2]> : tensor<3xi64> } : (memref, memref) -> () return @@ -237,7 +237,7 @@ func @dynamic_broadcast_in_dim(%operand: memref, // CHECK-LABEL: func @static_broadcast_in_dim_no_expansion func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, %result: memref<5x10xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (memref<5xf32>, memref<5x10xf32>) -> () return @@ -255,7 +255,7 @@ func @static_broadcast_in_dim_no_expansion(%operand: memref<5xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_expansion func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, %result: memref<5x10x100xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[2, 0]> : tensor<2xi64> } : (memref<1x5xf32>, memref<5x10x100xf32>) -> () return @@ -274,7 +274,7 @@ func @static_broadcast_in_dim_expansion(%operand: memref<1x5xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_scalar func @static_broadcast_in_dim_scalar(%operand: memref, %result: memref<5x10xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[]> : tensor<0xi64> } : (memref, memref<5x10xf32>) -> () return @@ -291,7 +291,7 @@ func @static_broadcast_in_dim_scalar(%operand: memref, // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_one func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, %result: memref<1x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[0]> : tensor<1xi64> } : (memref<1xf32>, memref<1x5xf32>) -> () return @@ -307,7 +307,7 @@ func @static_broadcast_in_dim_with_one_to_one(%operand: memref<1xf32>, // CHECK-LABEL: func @static_broadcast_in_dim_with_one_to_many func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, %result: memref<5x5xf32>) { - "xla_lhlo.broadcast_in_dim"(%operand, %result) { + "lmhlo.broadcast_in_dim"(%operand, %result) { broadcast_dimensions = dense<[1]> : tensor<1xi64> } : (memref<1xf32>, memref<5x5xf32>) -> () return @@ -323,7 +323,7 @@ func @static_broadcast_in_dim_with_one_to_many(%operand: memref<1xf32>, // CHECK-LABEL: func @constant func @constant(%value: memref) { - "xla_lhlo.constant"(%value) { + "lmhlo.constant"(%value) { value = dense<10> : tensor } : (memref) -> () return @@ -335,7 +335,7 @@ func @constant(%value: memref) { // 
CHECK-LABEL: func @absf func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.abs"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -348,7 +348,7 @@ func @absf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @absi func @absi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.abs"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -364,7 +364,7 @@ func @absi(%input: memref<2x2xi32>, // CHECK-LABEL: func @ceil func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -376,7 +376,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i32_to_f32 func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -389,7 +389,7 @@ func @convert_i32_to_f32(%input: memref<2x2xi32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i16_to_i32 func @convert_i16_to_i32(%input: memref<2x2xi16>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi16>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -401,7 +401,7 @@ func @convert_i16_to_i32(%input: memref<2x2xi16>, // CHECK-LABEL: func @convert_i32_to_i16 func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi16>) -> () return } // CHECK: linalg.generic @@ -413,7 +413,7 @@ func @convert_i32_to_i16(%input: memref<2x2xi32>, %result: memref<2x2xi16>) { // CHECK-LABEL: func @convert_f32_to_f64 func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf64>) -> () return } // CHECK: linalg.generic @@ -425,7 +425,7 @@ func @convert_f32_to_f64(%input: memref<2x2xf32>, %result: memref<2x2xf64>) { // CHECK-LABEL: func @convert_f64_to_f32 func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf64>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -437,7 +437,7 @@ func @convert_f64_to_f32(%input: memref<2x2xf64>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_i32_to_i32 func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -448,7 +448,7 @@ func @convert_i32_to_i32(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @convert_f32_to_f32 func @convert_f32_to_f32(%input: memref<2x2xf32>, 
%result: memref<2x2xf32>) { - "xla_lhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -459,7 +459,7 @@ func @convert_f32_to_f32(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @convert_f32_to_i32 func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { - "xla_lhlo.convert"(%input, %result) + "lmhlo.convert"(%input, %result) : (memref<2x2xf32>, memref<2x2xi32>) -> () return } @@ -472,7 +472,7 @@ func @convert_f32_to_i32(%input: memref<2x2xf32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -485,7 +485,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sine"(%input, %result) + "lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -498,7 +498,7 @@ func @sin(%input: memref<2x2xf32>, // CHECK-LABEL: func @negf func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.negate"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -510,7 +510,7 @@ func @negf(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @negi func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { - "xla_lhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.negate"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } // CHECK: linalg.generic @@ -524,7 +524,7 @@ func @negi(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @rem func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.remainder"(%lhs, %rhs, %result) + "lmhlo.remainder"(%lhs, %rhs, %result) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -537,7 +537,7 @@ func @remainder(%lhs: memref<2x2xf32>, %rhs: memref<2x2xf32>, // CHECK-LABEL: func @rsqrt func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.rsqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -549,7 +549,7 @@ func @rsqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sign func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sign"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -562,7 +562,7 @@ func @sign(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sqrt func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sqrt"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -574,7 +574,7 @@ func @sqrt(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @tanh func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - 
"xla_lhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.tanh"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } // CHECK: linalg.generic @@ -588,7 +588,7 @@ func @tanh(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @complex(%real: memref<2x2xf32>, %imag: memref<2x2xf32>, %cplx: memref<2x2xcomplex>) { - "xla_lhlo.complex"(%real, %imag, %cplx) + "lmhlo.complex"(%real, %imag, %cplx) : (memref<2x2xf32>, memref<2x2xf32>, memref<2x2xcomplex>) -> () return } @@ -602,7 +602,7 @@ func @complex(%real: memref<2x2xf32>, // CHECK-LABEL: func @real func @real(%cplx: memref<2x2xcomplex>, %real: memref<2x2xf32>) { - "xla_lhlo.real"(%cplx, %real) + "lmhlo.real"(%cplx, %real) : (memref<2x2xcomplex>, memref<2x2xf32>) -> () return } @@ -616,7 +616,7 @@ func @real(%cplx: memref<2x2xcomplex>, // CHECK-LABEL: func @imag func @imag(%cplx: memref<2x2xcomplex>, %imag: memref<2x2xf32>) { - "xla_lhlo.imag"(%cplx, %imag) + "lmhlo.imag"(%cplx, %imag) : (memref<2x2xcomplex>, memref<2x2xf32>) -> () return } @@ -629,7 +629,7 @@ func @imag(%cplx: memref<2x2xcomplex>, // CHECK: func @slice(%[[IN:.*]]: memref, %[[OUT:.*]]: memref) func @slice(%operand: memref, %result: memref) { - "xla_lhlo.slice"(%operand, %result) { + "lmhlo.slice"(%operand, %result) { start_indices = dense<[0,1]> : tensor<2xi64>, limit_indices = dense<[2,3]> : tensor<2xi64>, strides = dense<[1,1]> : tensor<2xi64> @@ -653,7 +653,7 @@ func @slice(%operand: memref, %result: memref) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2) -> (d2)> // CHECK-LABEL: func @reshape_3D_2D func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x1x42xi32>, memref<12x42xi32>) -> () return } @@ -666,7 +666,7 @@ func @reshape_3D_2D(%arg0: memref<12x1x42xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d1, d2, d3)> // CHECK-LABEL: func @reshape_4D_2D func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x42x1x1xi32>, memref<12x42xi32>) -> () return } @@ -679,7 +679,7 @@ func @reshape_4D_2D(%arg0: memref<12x42x1x1xi32>, %arg1 : memref<12x42xi32>) { // CHECK-DAG: #[[MAP2:.*]] = affine_map<(d0, d1, d2, d3) -> (d2, d3)> // CHECK-LABEL: func @reshape_2D_4D func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { - "xla_lhlo.reshape"(%arg0, %arg1) + "lmhlo.reshape"(%arg0, %arg1) : (memref<12x42xi32>, memref<12x1x42x1xi32>) -> () return } @@ -692,7 +692,7 @@ func @reshape_2D_4D(%arg0: memref<12x42xi32>, %arg1 : memref<12x1x42x1xi32>) { // CHECK-DAG: #[[RESULT_MAP:.*]] = affine_map<(d0, d1) -> (d0, d1)> // CHECK-LABEL: func @reverse func @reverse(%arg0: memref<2x3xf32>, %arg1: memref<2x3xf32>) { - "xla_lhlo.reverse"(%arg0, %arg1) { + "lmhlo.reverse"(%arg0, %arg1) { dimensions = dense<1> : tensor<1xi64> } : (memref<2x3xf32>, memref<2x3xf32>) -> () return @@ -710,15 +710,15 @@ func @conv(%input: memref<3x5x5x3xf32>, %filter: memref<2x2x3x4xf32>, %output: m // CHECK-SAME: padding = dense<{{\[\[}}0, 1], [0, 1]]> : tensor<2x2xi64> // CHECK-SAME: strides = [2, 1]} // With all atributes explicitly specified. 
- "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, rhs_dilation = dense<[1, 2]> : tensor<2xi64>, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () // Dilation left unspecified, sets default dilation since linalg expects it. // CHECK: linalg.conv(%{{.+}}, %{{.+}}, %{{.+}}) // CHECK-SAME: dilations = [1, 1] // Padding is not set if it's zero. // CHECK-NOT: padding - "xla_lhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.convolution"(%filter, %input, %0) {batch_group_count = 1 : i64, dimension_numbers = {input_batch_dimension = 0 : i64, input_feature_dimension = 3 : i64, input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>, kernel_input_feature_dimension = 2 : i64, kernel_output_feature_dimension = 3 : i64, kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>, output_batch_dimension = 0 : i64, output_feature_dimension = 3 : i64, output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>}, feature_group_count = 1 : i64, window_strides = dense<[2, 1]> : tensor<2xi64>} : (memref<2x2x3x4xf32>, memref<3x5x5x3xf32>, memref<3x5x5x4xf32>) -> () - "xla_lhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%0, %output) : (memref<3x5x5x4xf32>, memref<3x5x5x4xf32>) -> () + "lmhlo.terminator"() : () -> () } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir index a9759c0dce7..a25a508b2d3 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-llvm.mlir @@ 
-2,7 +2,7 @@ // CHECK-LABEL: func @static_memref_cast func @static_memref_cast(%buf : memref<10x1x5xf32>) { - %0 = xla_lhlo.static_memref_cast %buf + %0 = lmhlo.static_memref_cast %buf : memref<10x1x5xf32> -> memref<10x5xf32, offset: 2, strides: [5, 1]> return } @@ -38,7 +38,7 @@ func @dynamic_memref_cast(%buf : memref) { %size_Y = constant 50 : index %stride_X = constant 1 : index %stride_Y = constant 0 : index - %0 = xla_lhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] + %0 = lmhlo.dynamic_memref_cast %buf(%size_X, %size_Y)[%stride_X, %stride_Y] : memref -> memref return } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir index a3d76ef0196..1530f59317d 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo-legalize-to-parallel-loops.mlir @@ -3,11 +3,11 @@ func @reduce(%arg: memref<100x10x5xf32>, %init: memref, %result: memref<100x5xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref<100x10x5xf32>, memref, memref<100x5xf32>) -> () return @@ -35,7 +35,7 @@ func @reduce(%arg: memref<100x10x5xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } @@ -49,11 +49,11 @@ func @reduce(%arg: memref<100x10x5xf32>, func @reduce_no_outer_loop(%arg: memref<100xf32>, %init: memref, %result: memref<1xf32>) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<100xf32>, memref, memref<1xf32>) -> () return @@ -76,7 +76,7 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] // CHECK: } @@ -88,11 +88,11 @@ func @reduce_no_outer_loop(%arg: memref<100xf32>, func @dynamic_reduce(%arg: memref, %init: memref, %result: memref) { - "xla_lhlo.reduce"(%arg, %init, %result) ( { + "lmhlo.reduce"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.add"(%lhs, %rhs, %res) + "lmhlo.add"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[1]> : tensor<1xi64>} : (memref, memref, memref) -> () return @@ -121,7 +121,7 @@ func 
@dynamic_reduce(%arg: memref, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.add"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } @@ -135,11 +135,11 @@ func @dynamic_reduce(%arg: memref, func @reduce_window(%arg: memref<112x112xf32>, %init: memref, %result: memref<56x56xf32>) { - "xla_lhlo.reduce_window"(%arg, %init, %result) ( { + "lmhlo.reduce_window"(%arg, %init, %result) ( { ^bb0(%lhs: memref, %rhs: memref, %res: memref): - "xla_lhlo.maximum"(%lhs, %rhs, %res) + "lmhlo.maximum"(%lhs, %rhs, %res) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () }) { padding = dense<[[0, 1], [0, 1]]> : tensor<2x2xi64>, window_dimensions = dense<[3, 3]> : tensor<2xi64>, @@ -189,7 +189,7 @@ func @reduce_window(%arg: memref<112x112xf32>, // CHECK: [[ACC_OUT_BUF:%.*]] = alloc() : memref // CHECK: store [[ELEM]], [[ELEM_BUF]][] : memref // CHECK: store [[ACC]], [[ACC_BUF]][] : memref -// CHECK: "xla_lhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) +// CHECK: "lmhlo.maximum"([[ELEM_BUF]], [[ACC_BUF]], [[ACC_OUT_BUF]]) // CHECK: [[ACC_RESULT:%.*]] = load [[ACC_OUT_BUF]][] : memref // CHECK: scf.reduce.return [[ACC_RESULT]] : f32 // CHECK: } diff --git a/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir b/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir index e793e2a5d0f..30ff9659d3b 100644 --- a/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/lhlo_ops.mlir @@ -4,7 +4,7 @@ // CHECK-LABEL: func @ceil func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -12,7 +12,7 @@ func @ceil(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.ceil"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -20,7 +20,7 @@ func @ceil(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -28,7 +28,7 @@ func @cos(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @cos func @cos(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () return } @@ -36,7 +36,7 @@ func @cos(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { func @cos(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.cosine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -44,7 +44,7 @@ func @cos(%input: 
memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -52,7 +52,7 @@ func @sin(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @sin func @sin(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { - "xla_lhlo.sine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () return } @@ -60,7 +60,7 @@ func @sin(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () + "lmhlo.sine"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -68,7 +68,7 @@ func @sin(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @add_memrefs func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.add"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -76,7 +76,7 @@ func @add_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1 // CHECK-LABEL: func @abs_memref func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.abs"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -84,7 +84,7 @@ func @abs_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @convert_memref func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { - "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () + "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<10xi32>) -> () return } @@ -92,7 +92,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<10xi32>) -> () { func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () + "lmhlo.convert"(%in, %out) : (memref<10xf32>, memref<9xi32>) -> () return } @@ -100,7 +100,7 @@ func @convert_memref(%in: memref<10xf32>, %out: memref<9xi32>) -> () { // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xf32>, memref<2x2xf32>) -> () return } @@ -108,7 +108,7 @@ func @exp(%input: memref<2x2xf32>, %result: memref<2x2xf32>) { // CHECK-LABEL: func @exp func @exp(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xcomplex>, memref<2x2xcomplex>) -> () return } @@ -116,7 +116,7 @@ func @exp(%input: memref<2x2xcomplex>, %result: memref<2x2xcomplex>) { func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.exponential"(%input, %result) : (memref<2x2xi32>, 
memref<2x2xi32>) -> () + "lmhlo.exponential"(%input, %result) : (memref<2x2xi32>, memref<2x2xi32>) -> () return } @@ -124,7 +124,7 @@ func @exp(%input: memref<2x2xi32>, %result: memref<2x2xi32>) { // CHECK-LABEL: func @log_memref func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.log"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -132,7 +132,7 @@ func @log_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @log_memref func @log_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.log"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.log"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -140,7 +140,7 @@ func @log_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.log"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -148,7 +148,7 @@ func @log_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @neg_memref func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.negate"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -156,7 +156,7 @@ func @neg_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @rsqrt_memref func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -164,7 +164,7 @@ func @rsqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @rsqrt_memref func @rsqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -172,7 +172,7 @@ func @rsqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.rsqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -180,7 +180,7 @@ func @rsqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @sqrt_memref func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -188,7 +188,7 @@ func @sqrt_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @sqrt_memref func @sqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.sqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -196,7 +196,7 @@ func @sqrt_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) - func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type 
values}} - "xla_lhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.sqrt"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -204,7 +204,7 @@ func @sqrt_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @sign_memref func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.sign"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -212,7 +212,7 @@ func @sign_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @tanh_memref func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xf32>, memref<10xf32>) -> () return } @@ -220,7 +220,7 @@ func @tanh_memref(%in: memref<10xf32>, %out: memref<10xf32>) -> () { // CHECK-LABEL: func @tanh_memref func @tanh_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) -> () { - "xla_lhlo.tanh"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xcomplex>, memref<10xcomplex>) -> () return } @@ -228,15 +228,15 @@ func @tanh_memref(%in: memref<10xcomplex>, %out: memref<10xcomplex>) - func @tanh_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.tanh"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } // ----- func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { - // expected-error@+1{{'xla_lhlo.tanh' op requires all operands to have the same type}} - "xla_lhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () + // expected-error@+1{{'lmhlo.tanh' op requires all operands to have the same type}} + "lmhlo.tanh"(%arg0, %arg1) : (memref<1xf32>, memref<2xf32>) -> () return } @@ -244,7 +244,7 @@ func @tanh_memref(%arg0: memref<1xf32>, %arg1: memref<2xf32>) -> () { // CHECK-LABEL: func @add_memref func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.add"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -252,7 +252,7 @@ func @add_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @div_memref func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.divide"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -260,7 +260,7 @@ func @div_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @max_memref func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.maximum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -268,7 +268,7 @@ func @max_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @min_memref func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + 
"lmhlo.minimum"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -276,7 +276,7 @@ func @min_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @mul_memref func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.multiply"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -284,7 +284,7 @@ func @mul_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @sub_memref func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.subtract"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -292,7 +292,7 @@ func @sub_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @and_memref func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -300,7 +300,7 @@ func @and_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32 // CHECK-LABEL: func @and_memref func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -308,7 +308,7 @@ func @and_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.and"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -316,7 +316,7 @@ func @and_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @or_memref func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -324,7 +324,7 @@ func @or_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32> // CHECK-LABEL: func @or_memref func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -332,7 +332,7 @@ func @or_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) - func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) 
-> () + "lmhlo.or"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -340,7 +340,7 @@ func @or_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32> // CHECK-LABEL: func @xor_memref func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32>) -> () { - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi32>, memref<10xi32>, memref<10xi32>) -> () return } @@ -348,7 +348,7 @@ func @xor_memref(%lhs: memref<10xi32>, %rhs: memref<10xi32>, %out: memref<10xi32 // CHECK-LABEL: func @xor_memref func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) -> () { - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xi1>, memref<10xi1>, memref<10xi1>) -> () return } @@ -356,7 +356,7 @@ func @xor_memref(%lhs: memref<10xi1>, %rhs: memref<10xi1>, %out: memref<10xi1>) func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () + "lmhlo.xor"(%lhs, %rhs, %out) : (memref<10xf32>, memref<10xf32>, memref<10xf32>) -> () return } @@ -364,7 +364,7 @@ func @xor_memref(%lhs: memref<10xf32>, %rhs: memref<10xf32>, %out: memref<10xf32 // CHECK-LABEL: func @broadcast_in_dim_memref func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) -> () { - "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () + "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (memref<1x2xi32>, memref<1x2x2xi32>) -> () return } @@ -372,7 +372,7 @@ func @broadcast_in_dim_memref(%arg0: memref<1x2xi32>, %out: memref<1x2x2xi32>) - // CHECK-LABEL: func @broadcast_in_dim_zero_rank_memref func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi32>) -> () { - "xla_lhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () + "lmhlo.broadcast_in_dim"(%arg0, %out) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (memref, memref<1x2x3xi32>) -> () return } @@ -381,10 +381,10 @@ func @broadcast_in_dim_zero_rank_memref(%arg0: memref, %out: memref<1x2x3xi // CHECK-LABEL: func @reduce_memref func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf32>) -> () { - "xla_lhlo.reduce"(%input, %init, %out) ( { + "lmhlo.reduce"(%input, %init, %out) ( { ^bb0(%arg1: memref, %arg2: memref, %result: memref): - "xla_lhlo.add"(%arg1, %arg2, %result) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.add"(%arg1, %arg2, %result) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () } ) {dimensions = dense<[0]> : tensor<1xi64>} : (memref<10xf32>, memref, memref<1xf32>) -> () return } @@ -393,14 +393,14 @@ func @reduce_memref(%input: memref<10xf32>, %init: memref, %out: memref<1xf // CHECK-LABEL: func @fusion_memref func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: memref<10xf32>, %out: memref<10xf32>) -> () { - "xla_lhlo.fusion"() ( { + "lmhlo.fusion"() ( { %0 = tensor_load %input1 : memref<10xf32> %1 = tensor_load %input2 : 
memref<10xf32> %2 = "mhlo.add"(%0, %1) {name = "add"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> %3 = tensor_load %input3 : memref<10xf32> %4 = "mhlo.multiply"(%2, %3) {name = "multiply"} : (tensor<10xf32>, tensor<10xf32>) -> tensor<10xf32> tensor_store %4, %out : memref<10xf32> - "xla_lhlo.terminator"() : () -> () + "lmhlo.terminator"() : () -> () } ) : () -> () return } @@ -409,18 +409,18 @@ func @fusion_memref(%input1: memref<10xf32>, %input2: memref<10xf32>, %input3: m // CHECK-LABEL: func @case_memref func @case_memref(%index: memref, %operand_1: memref, %operand_2: memref, %operand_3: memref, %out: memref) -> () { - "xla_lhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { + "lmhlo.case"(%index, %operand_1, %operand_2, %operand_3, %out) ( { ^bb0(%arg0: memref): - "xla_lhlo.negate"(%arg0, %out) : (memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.negate"(%arg0, %out) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () }, { ^bb0(%arg0: memref): - "xla_lhlo.copy"(%arg0, %out) : (memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.copy"(%arg0, %out) : (memref, memref) -> () + "lmhlo.terminator"() : () -> () }, { ^bb0(%arg0: memref): - "xla_lhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () - "xla_lhlo.terminator"() : () -> () + "lmhlo.add"(%arg0, %arg0, %out) : (memref, memref, memref) -> () + "lmhlo.terminator"() : () -> () } ) {operand_segment_sizes = dense<[1, 3, 1]> : vector<3xi32>} : (memref, memref, memref, memref, memref) -> () @@ -430,7 +430,7 @@ func @case_memref(%index: memref, %operand_1: memref, %operand_2: memr // ----- func @static_memref_cast(%in: memref<10x1xf32>) { - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x1xf32> -> memref<10xf32, offset: 0, strides: [1]> return } @@ -440,7 +440,7 @@ func @static_memref_cast(%in: memref<10x1xf32>) { func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { // expected-error @+1 {{operand must have static shape}} - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x?xf32> -> memref<10x1xf32, offset: 0, strides: [10, 1]> return } @@ -449,7 +449,7 @@ func @static_memref_cast_dynamic_operand(%in: memref<10x?xf32>) { func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { // expected-error @+1 {{result must have static shape}} - %out = xla_lhlo.static_memref_cast %in + %out = lmhlo.static_memref_cast %in : memref<10x1xf32> -> memref<10x?xf32, offset: 0, strides: [?, ?]> return } @@ -459,7 +459,7 @@ func @static_memref_cast_dynamic_result(%in: memref<10x1xf32>) { func @dynamic_memref_cast(%in: memref) { %size = constant 10 : index %step = constant 1 : index - %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + %out = lmhlo.dynamic_memref_cast %in(%size)[%step] : memref -> memref return } @@ -471,7 +471,7 @@ func @dynamic_memref_cast_incompatible_result_type(%in: memref) { // expected-error @+3 {{`sizes` args count must be equal to the rank of the output memref}} %size = constant 10 : index %step = constant 1 : index - %out = xla_lhlo.dynamic_memref_cast %in(%size)[%step] + %out = lmhlo.dynamic_memref_cast %in(%size)[%step] : memref -> memref return } @@ -483,19 +483,19 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, // CHECK-SAME: [[UNRANKED:%.*]]: memref<*xf32>, [[SHAPE_1:%.*]]: memref<1xi32>, // CHECK-SAME: [[SHAPE_2:%.*]]: memref<2xi32>, [[SHAPE_3:%.*]]: memref - // CHECK-NEXT: [[DYN_VEC:%.*]] = xla_lhlo.reshape_memref_cast 
[[UNRANKED]] + // CHECK-NEXT: [[DYN_VEC:%.*]] = lmhlo.reshape_memref_cast [[UNRANKED]] // CHECK-SAME: : (memref<*xf32>, memref<1xi32>) -> memref - %dyn_vec = xla_lhlo.reshape_memref_cast %unranked(%shape1) + %dyn_vec = lmhlo.reshape_memref_cast %unranked(%shape1) : (memref<*xf32>, memref<1xi32>) -> memref - // CHECK-NEXT: [[DYN_MAT:%.*]] = xla_lhlo.reshape_memref_cast [[DYN_VEC]] + // CHECK-NEXT: [[DYN_MAT:%.*]] = lmhlo.reshape_memref_cast [[DYN_VEC]] // CHECK-SAME: : (memref, memref<2xi32>) -> memref - %dyn_mat = xla_lhlo.reshape_memref_cast %dyn_vec(%shape2) + %dyn_mat = lmhlo.reshape_memref_cast %dyn_vec(%shape2) : (memref, memref<2xi32>) -> memref - // CHECK-NEXT: {{%.*}} = xla_lhlo.reshape_memref_cast [[DYN_MAT]] + // CHECK-NEXT: {{%.*}} = lmhlo.reshape_memref_cast [[DYN_MAT]] // CHECK-SAME: : (memref, memref) -> memref<*xf32> - %new_unranked = xla_lhlo.reshape_memref_cast %dyn_mat(%shape3) + %new_unranked = lmhlo.reshape_memref_cast %dyn_mat(%shape3) : (memref, memref) -> memref<*xf32> return } @@ -505,7 +505,7 @@ func @reshape_memref_cast(%unranked: memref<*xf32>, %shape1: memref<1xi32>, func @reshape_memref_cast_element_type_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{element types of source and destination memref types should be the same}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref } @@ -514,7 +514,7 @@ func @reshape_memref_cast_element_type_mismatch( func @reshape_memref_cast_dst_ranked_shape_unranked( %buf: memref<*xf32>, %shape: memref) { // expected-error @+1 {{cannot use shape operand with dynamic length to cast statically-ranked memref type}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref) -> memref return } @@ -524,7 +524,7 @@ func @reshape_memref_cast_dst_ranked_shape_unranked( func @reshape_memref_cast_dst_shape_rank_mismatch( %buf: memref<*xf32>, %shape: memref<1xi32>) { // expected-error @+1 {{length of shape operand differs from the result's memref rank}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<*xf32>, memref<1xi32>) -> memref return } @@ -535,7 +535,7 @@ func @reshape_memref_cast_affine_map_is_not_identity( %buf: memref<4x4xf32, offset: 0, strides: [3, 2]>, %shape: memref<1xi32>) { // expected-error @+1 {{operand memref type should have identity affine map}} - xla_lhlo.reshape_memref_cast %buf(%shape) + lmhlo.reshape_memref_cast %buf(%shape) : (memref<4x4xf32, offset: 0, strides: [3, 2]>, memref<1xi32>) -> memref<8xf32> return @@ -545,7 +545,7 @@ func @reshape_memref_cast_affine_map_is_not_identity( // CHECK-LABEL: func @atan2_memrefs func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -553,7 +553,7 @@ func @atan2_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref // CHECK-LABEL: func @atan2_memrefs func @atan2_memrefs(%arg0: memref<1xcomplex>, %arg1: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -561,7 +561,7 @@ func @atan2_memrefs(%arg0: 
memref<1xcomplex>, %arg1: memref<1xcomplex> func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.atan2"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -569,7 +569,7 @@ func @atan2_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref // CHECK-LABEL: func @bitcast_convert_memrefs func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () + "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi32>) -> () return } @@ -577,7 +577,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi32>) -> func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () + "lmhlo.bitcast_convert"(%arg0, %arg_out) : (memref<1xf32>, memref<2xi32>) -> () return } @@ -585,7 +585,7 @@ func @bitcast_convert_memrefs(%arg0: memref<1xf32>, %arg_out: memref<2xi32>) -> // CHECK-LABEL: func @clz_memrefs func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.count_leading_zeros"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -593,7 +593,7 @@ func @clz_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @expm1_memrefs func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -601,7 +601,7 @@ func @expm1_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @expm1_memrefs func @expm1_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.exponential_minus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -609,7 +609,7 @@ func @expm1_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.floor"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -617,7 +617,7 @@ func @floor_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.floor"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -625,7 +625,7 @@ func @floor_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @imag_memrefs func @imag_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + "lmhlo.imag"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () return } @@ -633,7 +633,7 @@ func @imag_memrefs(%arg0: 
memref<1xcomplex>, %arg_out: memref<1xf32>) -> () func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error@+1{{must be memref of complex-type values}} - "xla_lhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.imag"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -641,7 +641,7 @@ func @imag_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @real_memrefs func @real_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () + "lmhlo.real"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xf32>) -> () return } @@ -649,7 +649,7 @@ func @real_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xf32>) -> () func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error@+1{{must be memref of complex-type values}} - "xla_lhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.real"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -657,7 +657,7 @@ func @real_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @is_finite_memrefs func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { - "xla_lhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () + "lmhlo.is_finite"(%arg0, %arg_out) : (memref<1xf32>, memref<1xi1>) -> () return } @@ -665,7 +665,7 @@ func @is_finite_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xi1>) -> () { // CHECK-LABEL: func @log1p_memrefs func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -673,7 +673,7 @@ func @log1p_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @log1p_memrefs func @log1p_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex>) -> () { - "xla_lhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () + "lmhlo.log_plus_one"(%arg0, %arg_out) : (memref<1xcomplex>, memref<1xcomplex>) -> () return } @@ -681,7 +681,7 @@ func @log1p_memrefs(%arg0: memref<1xcomplex>, %arg_out: memref<1xcomplex, %out: memref<10xi32>) -> () { // expected-error@+1{{must be memref of floating-point or complex-type values}} - "xla_lhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () + "lmhlo.log_plus_one"(%in, %out) : (memref<10xi32>, memref<10xi32>) -> () return } @@ -689,7 +689,7 @@ func @log1p_memref(%in: memref<10xi32>, %out: memref<10xi32>) -> () { // CHECK-LABEL: func @not_memrefs func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -697,7 +697,7 @@ func @not_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @not_memrefs func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xi1>, memref<1xi1>) -> () return } @@ -705,7 +705,7 @@ func @not_memrefs(%arg0: memref<1xi1>, %arg_out: memref<1xi1>) -> () { func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 
8/16/32/64-bit unsigned integer or pred (AKA boolean or 1-bit integer) values}} - "xla_lhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.not"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -713,7 +713,7 @@ func @not_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @popcnt_memrefs func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -721,7 +721,7 @@ func @popcnt_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.popcnt"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -729,7 +729,7 @@ func @popcnt_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // CHECK-LABEL: func @reduce_precision_memrefs func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.reduce_precision"(%arg0, %arg_out) { exponent_bits = 4 : i32, mantissa_bits = 4 : i32 } : (memref<1xf32>, memref<1xf32>) -> () return } @@ -737,7 +737,7 @@ func @reduce_precision_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> // CHECK-LABEL: func @round_memrefs func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { - "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () + "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xf32>, memref<1xf32>) -> () return } @@ -745,7 +745,7 @@ func @round_memrefs(%arg0: memref<1xf32>, %arg_out: memref<1xf32>) -> () { func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // expected-error@+1{{must be memref of floating-point values}} - "xla_lhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () + "lmhlo.round_nearest_afz"(%arg0, %arg_out) : (memref<1xi32>, memref<1xi32>) -> () return } @@ -753,7 +753,7 @@ func @round_memrefs(%arg0: memref<1xi32>, %arg_out: memref<1xi32>) -> () { // CHECK-LABEL: func @shift_left_memrefs func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -761,7 +761,7 @@ func @shift_left_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: m func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_left"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -769,7 +769,7 @@ func @shift_left_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: m // CHECK-LABEL: func @shift_right_arithmetic_memrefs func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, 
%arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -777,7 +777,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_right_arithmetic"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -785,7 +785,7 @@ func @shift_right_arithmetic_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, // CHECK-LABEL: func @shift_right_logical_memrefs func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %arg_out: memref<1xi32>) -> () { - "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () + "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xi32>, memref<1xi32>, memref<1xi32>) -> () return } @@ -793,7 +793,7 @@ func @shift_right_logical_memrefs(%arg0: memref<1xi32>, %arg1: memref<1xi32>, %a func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %arg_out: memref<1xf32>) -> () { // expected-error @+1 {{must be memref of 8/16/32/64-bit signless integer or 8/16/32/64-bit unsigned integer values}} - "xla_lhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () + "lmhlo.shift_right_logical"(%arg0, %arg1, %arg_out) : (memref<1xf32>, memref<1xf32>, memref<1xf32>) -> () return } @@ -801,14 +801,14 @@ func @shift_right_logical_memrefs(%arg0: memref<1xf32>, %arg1: memref<1xf32>, %a // CHECK-LABEL: func @all_reduce_memrefs func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () { - "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + "lmhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): %max = mhlo.maximum %lhs, %rhs : tensor "mhlo.return"(%max) : (tensor) -> () }) { replica_groups = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64> }: (memref<10xf32>, memref<10xf32>) -> () - "xla_lhlo.all_reduce"(%arg0, %arg_out) ({ + "lmhlo.all_reduce"(%arg0, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): %max = mhlo.maximum %lhs, %rhs : tensor "mhlo.return"(%max) : (tensor) -> () @@ -826,11 +826,11 @@ func @all_reduce_memrefs(%arg0: memref<10xf32>, %arg_out: memref<10xf32>) -> () // CHECK-LABEL: func @collective_permute_memrefs func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128x32xf32>) -> () { - "xla_lhlo.collective_permute"(%arg0, %arg_out) { + "lmhlo.collective_permute"(%arg0, %arg_out) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64> } : (memref<128x32xf32>, memref<128x32xf32>) -> () - "xla_lhlo.collective_permute"(%arg0, %arg_out) { + "lmhlo.collective_permute"(%arg0, %arg_out) { source_target_pairs = dense<[[0, 1], [1, 2], [2, 3]]> : tensor<3x2xi64>, channel_id = { handle = 5 : i64, type = 2 : i64 } } : (memref<128x32xf32>, memref<128x32xf32>) -> () @@ -841,7 +841,7 @@ func @collective_permute_memrefs(%arg0: memref<128x32xf32>, %arg_out: memref<128 // CHECK-LABEL: func @fft_memrefs func @fft_memrefs(%arg0: 
memref<3x9xf32>, %arg_out: memref<3x5xcomplex>) -> () { - "xla_lhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> () + "lmhlo.fft"(%arg0, %arg_out) {fft_length = dense<9> : tensor<1xi64>, fft_type = "RFFT"} : (memref<3x9xf32>, memref<3x5xcomplex>) -> () return } @@ -852,7 +852,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg3: memref<8xf32>, %arg4: memref<8x8x8x8xf32>, %grad_operand: memref<8x8x8x8xf32>, %grad_scale: memref<8xf32>, %grad_offset: memref<8xf32>) -> () { - "xla_lhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_grad"(%arg0, %arg1, %arg2, %arg3, %arg4, %grad_operand, %grad_scale, %grad_offset) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () return @@ -863,7 +863,7 @@ func @batch_norm_grad_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, // CHECK-LABEL: func @batch_norm_inference_memrefs func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, %arg3: memref<8xf32>, %arg4: memref<8xf32>, %arg_out: memref<8x8x8x8xf32>) -> () { - "xla_lhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_inference"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg_out) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>) -> () return } @@ -874,7 +874,7 @@ func @batch_norm_inference_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf32>, %arg2: memref<8xf32>, %output: memref<8x8x8x8xf32>, %batch_mean: memref<8xf32>, %batch_var: memref<8xf32>) -> () { - "xla_lhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} + "lmhlo.batch_norm_training"(%arg0, %arg1, %arg2, %output, %batch_mean, %batch_var) {epsilon = 1.000000e-03 : f32, feature_index = 3 : i64} : (memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>, memref<8x8x8x8xf32>, memref<8xf32>, memref<8xf32>) -> () return } @@ -883,8 +883,8 @@ func @batch_norm_training_memrefs(%arg0: memref<8x8x8x8xf32>, %arg1: memref<8xf3 // CHECK-LABEL: func @cholesky_memrefs func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291xf32>) -> () { - "xla_lhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () - "xla_lhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + "lmhlo.cholesky"(%arg0, %arg_out) : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () + "lmhlo.cholesky"(%arg0, %arg_out) { lower = true } : (memref<1x291x291xf32>, memref<1x291x291xf32>) -> () return } @@ -892,7 +892,7 @@ func @cholesky_memrefs(%arg0: memref<1x291x291xf32>, %arg_out: memref<1x291x291x // CHECK-LABEL: func @infeed_memrefs func @infeed_memrefs(%arg_out: memref<3xf32>) -> () { - "xla_lhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () + "lmhlo.infeed"(%arg_out) { config = "x" } : (memref<3xf32>) -> () return } @@ -900,7 +900,7 @@ func @infeed_memrefs(%arg_out: memref<3xf32>) 
-> () { // CHECK-LABEL: func @outfeed_memrefs func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { - "xla_lhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () + "lmhlo.outfeed"(%arg0) { config = "x" } : (memref<3xf32>) -> () return } @@ -908,7 +908,7 @@ func @outfeed_memrefs(%arg0: memref<3xf32>) -> () { // CHECK-LABEL: func @replica_id_memrefs func @replica_id_memrefs(%arg_out: memref) -> () { - "xla_lhlo.replica_id"(%arg_out) : (memref) -> () + "lmhlo.replica_id"(%arg_out) : (memref) -> () return } @@ -916,7 +916,7 @@ func @replica_id_memrefs(%arg_out: memref) -> () { // CHECK-LABEL: func @triangular_solve_memrefs func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, %arg_out: memref<3x4xf32>) -> () { - "xla_lhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} + "lmhlo.triangular_solve"(%arg0, %arg1, %arg_out) {left_side = true, lower = true, transpose_a = "NO_TRANSPOSE", unit_diagonal = true} : (memref<4x4xf32>, memref<3x4xf32>, memref<3x4xf32>) -> () return } @@ -925,9 +925,9 @@ func @triangular_solve_memrefs(%arg0: memref<4x4xf32>, %arg1: memref<3x4xf32>, % // CHECK-LABEL: func @while_memrefs func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { - "xla_lhlo.while"(%arg0, %arg_out) ( - { ^bb0(%arg: memref, %cond: memref): "xla_lhlo.terminator"() : () -> () }, - { ^bb0(%arg: memref, %body_out: memref): "xla_lhlo.terminator"() : () -> () } + "lmhlo.while"(%arg0, %arg_out) ( + { ^bb0(%arg: memref, %cond: memref): "lmhlo.terminator"() : () -> () }, + { ^bb0(%arg: memref, %body_out: memref): "lmhlo.terminator"() : () -> () } ) : (memref, memref) -> () return } @@ -936,9 +936,9 @@ func @while_memrefs(%arg0: memref, %arg_out: memref) -> () { // CHECK-LABEL: func @while_memrefs func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref, %arg1_out: memref<5xf32>) -> () { - "xla_lhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %cond: memref): "xla_lhlo.terminator"() : () -> () }, - { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "xla_lhlo.terminator"() : () -> () } + "lmhlo.while"(%arg0, %arg1, %arg0_out, %arg1_out) ( + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %cond: memref): "lmhlo.terminator"() : () -> () }, + { ^bb0(%cur0: memref, %cur1: memref<5xf32>, %body_out0: memref, %body_out1: memref<5xf32>): "lmhlo.terminator"() : () -> () } ) : (memref, memref<5xf32>, memref, memref<5xf32>) -> () return } @@ -947,7 +947,7 @@ func @while_memrefs(%arg0: memref, %arg1: memref<5xf32>, %arg0_out: memref< // CHECK-LABEL: func @bitcast_memrefs func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { - "xla_lhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () + "lmhlo.bitcast"(%arg0, %arg_out) : (memref<1xf64>, memref<2xi32>) -> () return } @@ -956,7 +956,7 @@ func @bitcast_memrefs(%arg0: memref<1xf64>, %arg_out: memref<2xi32>) -> () { // CHECK-LABEL: func @scatter_memrefs func @scatter_memrefs(%input: memref<200x100x300xf32>, %indices: memref<10x2xi32>, %updates: memref<10x300xf32>, %arg_out: memref<200x100x300xf32>) -> () { - "xla_lhlo.scatter" (%input, %indices, %updates, %arg_out) ({ + "lmhlo.scatter" (%input, %indices, %updates, %arg_out) ({ ^bb0(%lhs: tensor, %rhs: tensor): // no predecessors %add = mhlo.add %lhs, %rhs : tensor "mhlo.return"(%add) : (tensor) -> () @@ -977,7 +977,7 @@ func @scatter_memrefs(%input: memref<200x100x300xf32>, 
%indices: memref<10x2xi32 // CHECK-LABEL: func @map_memrefs func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<20xf32>) -> () { - "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + "lmhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): %c = mhlo.add %a, %b : tensor "mhlo.return"(%c) : (tensor) -> () @@ -989,7 +989,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref<10xf32>) -> () { // expected-error@+1{{requires the same shape for all operands}} - "xla_lhlo.map"(%arg0, %arg1, %arg_out) ({ + "lmhlo.map"(%arg0, %arg1, %arg_out) ({ ^bb0(%a: tensor, %b: tensor): %c = mhlo.add %a, %b : tensor "mhlo.return"(%c) : (tensor) -> () @@ -1001,7 +1001,7 @@ func @map_memrefs(%arg0: memref<20xf32>, %arg1: memref<20xf32>, %arg_out: memref // CHECK-LABEL: func @rng_get_and_update_state_memrefs func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { - "xla_lhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () + "lmhlo.rng_get_and_update_state"(%state) { delta = 1 : i64 } : (memref<1xui64>) -> () return } @@ -1010,7 +1010,7 @@ func @rng_get_and_update_state_memrefs(%state: memref<1xui64>) -> () { // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () @@ -1023,7 +1023,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () @@ -1036,7 +1036,7 @@ func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, // CHECK-LABEL: func @sort_memrefs func @sort_memrefs(%arg0: memref<16x16xf32>, %arg1: memref<16x16xf16>, %out0: memref<16x16xf32>, %out1: memref<16x16xf16>) -> () { - "xla_lhlo.sort"(%arg0, %arg1, %out0, %out1) ( { + "lmhlo.sort"(%arg0, %arg1, %out0, %out1) ( { ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): %7 = "mhlo.compare"(%a, %b) {comparison_direction = "GT"} : (tensor, tensor) -> tensor "mhlo.return"(%7) : (tensor) -> () diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc index ea62ce1f3fd..348dec47ad2 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -128,7 +128,7 @@ Status LowerTfOpToLhloWithDynamicShapes(mlir::ModuleOp module) { pm.addNestedPass(absl::make_unique()); pm.addPass(mlir::mhlo::createLegalizeToLhloPass( /*results_escape_functions=*/true)); - pm.addNestedPass(mlir::xla_lhlo::createLhloCopyRemovalPass()); + pm.addNestedPass(mlir::lmhlo::createLhloCopyRemovalPass()); if (failed(pm.run(module))) { return InternalError("Lowering TF to LHLO failed."); diff --git 
a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir index 40ba48b3779..09a85177fae 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/ops.mlir @@ -1,11 +1,11 @@ // RUN: xla-opt -split-input-file -xla-hlo-to-lhlo-with-xla %s | FileCheck --enable-var-scope %s // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.abs +// CHECK: lmhlo.abs // CHECK-SAME: %[[ARG0]], %[[VIEW]] %abs = "mhlo.abs"(%value) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %abs : tensor<2x2xf32> @@ -14,12 +14,12 @@ func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.add +// CHECK: lmhlo.add // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.add"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> @@ -29,12 +29,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> -// CHECK: lhlo.and +// CHECK: lmhlo.and // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.and"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> @@ -44,11 +44,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.ceil +// CHECK: lmhlo.ceil // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.ceil"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -57,12 +57,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<1x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<1x2xf32>, 
%value1: tensor<1x2xf32>) -> tensor<1x2xcomplex> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> -// CHECK: lhlo.complex +// CHECK: lmhlo.complex // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.complex"(%value0, %value1) : (tensor<1x2xf32>, tensor<1x2xf32>) -> (tensor<1x2xcomplex>) @@ -72,11 +72,11 @@ func @main(%value0: tensor<1x2xf32>, %value1: tensor<1x2xf32>) -> tensor<1x2xcom // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<1x2xcomplex> -// CHECK: lhlo.cosine +// CHECK: lmhlo.cosine // CHECK-SAME: %[[ARG0]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.cosine"(%value0) : (tensor<1x2xcomplex>) -> tensor<1x2xcomplex> @@ -86,12 +86,12 @@ func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xcomplex> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.divide +// CHECK: lmhlo.divide // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.divide"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> @@ -101,11 +101,11 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.exponential +// CHECK: lmhlo.exponential // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.exponential"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -114,11 +114,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.log +// CHECK: lmhlo.log // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.log"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -127,12 +127,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> 
-// CHECK: lhlo.maximum +// CHECK: lmhlo.maximum // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.maximum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> @@ -142,12 +142,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.minimum +// CHECK: lmhlo.minimum // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.minimum"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> @@ -157,12 +157,12 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.multiply +// CHECK: lmhlo.multiply // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.multiply"(%value0, %value1) : (tensor<2x2xf32>, tensor<2x2xf32>) -> tensor<2x2xf32> @@ -172,11 +172,11 @@ func @main(%value0: tensor<2x2xf32>, %value1: tensor<2x2xf32>) -> tensor<2x2xf32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.negate +// CHECK: lmhlo.negate // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.negate"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -185,11 +185,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> -// CHECK: lhlo.real +// CHECK: lmhlo.real // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.real"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) return %res : tensor<1x2xf32> @@ -198,11 +198,11 @@ func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<1x2xcomplex> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<8xi8> func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<8xi8> to memref<1x2xf32> -// CHECK: lhlo.imag +// CHECK: lmhlo.imag // CHECK-SAME: %[[ARG0]], 
%[[VIEW]] %res = "mhlo.imag"(%value0) : (tensor<1x2xcomplex>) -> (tensor<1x2xf32>) return %res : tensor<1x2xf32> @@ -211,12 +211,12 @@ func @main(%value0: tensor<1x2xcomplex>) -> tensor<1x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> -// CHECK: lhlo.remainder +// CHECK: lmhlo.remainder // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.remainder"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> @@ -226,11 +226,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.rsqrt +// CHECK: lmhlo.rsqrt // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.rsqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -239,13 +239,13 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {xla_lhlo.params = 1 -// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {xla_lhlo.params = 2 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi1> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xf32> {lmhlo.params = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<2x2xf32> {lmhlo.params = 2 // CHECK-SAME: %[[ARG3:.*]]: memref<16xi8> func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.select +// CHECK: lmhlo.select // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[ARG2]], %[[VIEW]] // CHECK-NEXT: return %0 = "mhlo.select"(%pred, %lhs, %rhs) : (tensor<2x2xi1>, tensor<2x2xf32>, tensor<2x2xf32>) -> (tensor<2x2xf32>) @@ -255,11 +255,11 @@ func @main(%pred: tensor<2x2xi1>, %lhs: tensor<2x2xf32>, %rhs: tensor<2x2xf32>) // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.sign +// CHECK: lmhlo.sign // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.sign"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -268,11 +268,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.sqrt +// CHECK: lmhlo.sqrt // 
CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.sqrt"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -281,12 +281,12 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {xla_lhlo.params = 0 -// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {xla_lhlo.params = 1 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xi32> {lmhlo.params = 0 +// CHECK-SAME: %[[ARG1:.*]]: memref<2x2xi32> {lmhlo.params = 1 // CHECK-SAME: %[[ARG2:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xi32> -// CHECK: lhlo.subtract +// CHECK: lmhlo.subtract // CHECK-SAME: %[[ARG0]], %[[ARG1]], %[[VIEW]] // CHECK-NEXT: return %res = "mhlo.subtract"(%value0, %value1) : (tensor<2x2xi32>, tensor<2x2xi32>) -> tensor<2x2xi32> @@ -296,11 +296,11 @@ func @main(%value0: tensor<2x2xi32>, %value1: tensor<2x2xi32>) -> tensor<2x2xi32 // ----- // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 // CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK: %[[VIEW:.*]] = {{.*}} memref<16xi8> to memref<2x2xf32> -// CHECK: lhlo.tanh +// CHECK: lmhlo.tanh // CHECK-SAME: %[[ARG0]], %[[VIEW]] %res = "mhlo.tanh"(%value0) : (tensor<2x2xf32>) -> tensor<2x2xf32> return %res : tensor<2x2xf32> @@ -311,11 +311,11 @@ func @main(%value0: tensor<2x2xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @main // CHECK-SAME: %[[ARG0:.*]]: memref<5x5xi32> // CHECK-SAME: %[[ARG1:.*]]: memref<5x5xf32> -// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {xla_lhlo.alloc = 0 -// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {xla_lhlo.alloc = 1 +// CHECK-SAME: %[[ARG2:.*]]: memref<100xi8> {lmhlo.alloc = 0 +// CHECK-SAME: %[[ARG3:.*]]: memref<100xi8> {lmhlo.alloc = 1 // CHECK: %[[VIEW0:.*]] = std.view %[[ARG2]]{{.*}} : memref<100xi8> to memref<5x5xi32> // CHECK: %[[VIEW1:.*]] = std.view %[[ARG3]]{{.*}} : memref<100xi8> to memref<5x5xf32> -// CHECK: "xla_lhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]]) +// CHECK: "lmhlo.sort"(%[[ARG0]], %[[ARG1]], %[[VIEW0]], %[[VIEW1]]) func @main(%key: tensor<5x5xi32>, %value: tensor<5x5xf32>) -> tuple, tensor<5x5xf32>> { %res = "mhlo.sort"(%key, %value) ({ ^bb0(%a: tensor, %b: tensor, %c: tensor, %d: tensor): diff --git a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir index d442319e7b2..cc07624d63d 100644 --- a/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir +++ b/tensorflow/compiler/mlir/xla/tests/hlo_to_lhlo_with_xla/passthrough.mlir @@ -3,14 +3,14 @@ // Current allocation will lead to one buffer argument for the "value" and // another one for the output, an no returned values. // CHECK-LABEL: func @main -// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {xla_lhlo.params = 0 : index}, -// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {xla_lhlo.alloc = 0 : index, xla_lhlo.liveout = true} +// CHECK-SAME: %[[ARG0:.*]]: memref<2x2xf32> {lmhlo.params = 0 : index}, +// CHECK-SAME: %[[ARG1:.*]]: memref<16xi8> {lmhlo.alloc = 0 : index, lmhlo.liveout = true} // CHECK-SAME: ) { func @main(%value: tensor<2x2xf32>) -> tensor<2x2xf32> { // The only expected instruction is a copy from the input into the output. 
   // CHECK: %[[C0:.*]] = constant 0 : index
   // CHECK: %[[OUTPUT:.*]] = std.view %[[ARG1]][%[[C0]]][] : memref<16xi8> to memref<2x2xf32>
-  // CHECK: xla_lhlo.copy
+  // CHECK: lmhlo.copy
   // CHECK-SAME: %[[ARG0]], %[[OUTPUT]]
   return %value : tensor<2x2xf32>
 }
diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
index c14f6803961..519068893e7 100644
--- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
+++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.cc
@@ -190,51 +190,51 @@ Status LhloDialectEmitter::DefaultAction(HloInstruction* instr) {
   using ::xla::HloOpcode;
   switch (instr->opcode()) {
     case HloOpcode::kAbs:
-      return CreateOpWithoutAttrs<xla_lhlo::AbsOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::AbsOp>(instr).status();
     case HloOpcode::kAdd:
-      return CreateOpWithoutAttrs<xla_lhlo::AddOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::AddOp>(instr).status();
     case HloOpcode::kAnd:
-      return CreateOpWithoutAttrs<xla_lhlo::AndOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::AndOp>(instr).status();
     case HloOpcode::kCeil:
-      return CreateOpWithoutAttrs<xla_lhlo::CeilOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::CeilOp>(instr).status();
     case HloOpcode::kComplex:
-      return CreateOpWithoutAttrs<xla_lhlo::ComplexOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::ComplexOp>(instr).status();
     case HloOpcode::kCopy:
-      return CreateOpWithoutAttrs<xla_lhlo::CopyOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::CopyOp>(instr).status();
     case HloOpcode::kCos:
-      return CreateOpWithoutAttrs<xla_lhlo::CosOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::CosOp>(instr).status();
    case HloOpcode::kDivide:
-      return CreateOpWithoutAttrs<xla_lhlo::DivOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::DivOp>(instr).status();
     case HloOpcode::kExp:
-      return CreateOpWithoutAttrs<xla_lhlo::ExpOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::ExpOp>(instr).status();
     case HloOpcode::kImag:
-      return CreateOpWithoutAttrs<xla_lhlo::ImagOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::ImagOp>(instr).status();
     case HloOpcode::kLog:
-      return CreateOpWithoutAttrs<xla_lhlo::LogOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::LogOp>(instr).status();
     case HloOpcode::kMaximum:
-      return CreateOpWithoutAttrs<xla_lhlo::MaxOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::MaxOp>(instr).status();
     case HloOpcode::kMinimum:
-      return CreateOpWithoutAttrs<xla_lhlo::MinOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::MinOp>(instr).status();
     case HloOpcode::kMultiply:
-      return CreateOpWithoutAttrs<xla_lhlo::MulOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::MulOp>(instr).status();
     case HloOpcode::kNegate:
-      return CreateOpWithoutAttrs<xla_lhlo::NegOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::NegOp>(instr).status();
     case HloOpcode::kReal:
-      return CreateOpWithoutAttrs<xla_lhlo::RealOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::RealOp>(instr).status();
     case HloOpcode::kRemainder:
-      return CreateOpWithoutAttrs<xla_lhlo::RemOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::RemOp>(instr).status();
     case HloOpcode::kRsqrt:
-      return CreateOpWithoutAttrs<xla_lhlo::RsqrtOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::RsqrtOp>(instr).status();
     case HloOpcode::kSelect:
-      return CreateOpWithoutAttrs<xla_lhlo::SelectOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::SelectOp>(instr).status();
     case HloOpcode::kSign:
-      return CreateOpWithoutAttrs<xla_lhlo::SignOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::SignOp>(instr).status();
     case HloOpcode::kSqrt:
-      return CreateOpWithoutAttrs<xla_lhlo::SqrtOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::SqrtOp>(instr).status();
     case HloOpcode::kSubtract:
-      return CreateOpWithoutAttrs<xla_lhlo::SubOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::SubOp>(instr).status();
     case HloOpcode::kTanh:
-      return CreateOpWithoutAttrs<xla_lhlo::TanhOp>(instr).status();
+      return CreateOpWithoutAttrs<lmhlo::TanhOp>(instr).status();
     default:
       llvm::errs() << instr->ToString();
       return tensorflow::errors::Internal(
@@ -246,7 +246,7 @@ Status 
LhloDialectEmitter::DefaultAction(HloInstruction* instr) { StatusOr LhloDialectEmitter::EmitSortOp( HloInstruction* instr) { - TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); + TF_ASSIGN_OR_RETURN(auto sort, CreateOpWithoutAttrs(instr)); auto* sort_instr = ::xla::Cast<::xla::HloSortInstruction>(instr); sort.dimensionAttr(builder_.getI64IntegerAttr(sort_instr->sort_dimension())); sort.is_stableAttr(builder_.getBoolAttr(sort_instr->is_stable())); @@ -379,16 +379,16 @@ Status LhloDialectEmitter::Initialize() { block->addArgument(arg_type); allocations_[alloc] = block->getArguments().back(); args_attrs.emplace_back(); - args_attrs.back().set(builder_.getIdentifier("xla_lhlo.params"), + args_attrs.back().set(builder_.getIdentifier("lmhlo.params"), builder_.getIndexAttr(alloc->parameter_number())); } else { block->addArgument(MemRefType::get({alloc->size()}, i8_type_)); allocations_[alloc] = block->getArguments().back(); args_attrs.emplace_back(); - args_attrs.back().set(builder_.getIdentifier("xla_lhlo.alloc"), + args_attrs.back().set(builder_.getIdentifier("lmhlo.alloc"), builder_.getIndexAttr(alloc->index())); if (alloc->maybe_live_out()) - args_attrs.back().set(builder_.getIdentifier("xla_lhlo.liveout"), + args_attrs.back().set(builder_.getIdentifier("lmhlo.liveout"), builder_.getBoolAttr(true)); } } diff --git a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h index bea77ecdfe1..ca40eb5804c 100644 --- a/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h +++ b/tensorflow/compiler/mlir/xla/transforms/mhlo_to_lhlo_with_xla.h @@ -58,7 +58,7 @@ class LhloDialectEmitter : public ::xla::DfsHloVisitorWithDefault { tensorflow::Status HandleSort(::xla::HloInstruction* instr) final; // Helper function that recursively visits the tuple structure in - // `current_shape`, and reconstruct a matching xla_lhlo::TupleOp. + // `current_shape`, and reconstruct a matching lmhlo::TupleOp. // Each leaf node is converted to an std.view op with corresponding offsets. // If no tuple presents, it simply returns a view of the buffer. tensorflow::Status CreateView(const ::xla::HloInstruction* instr, diff --git a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc index 9f81db551e2..5cb97dcd1d1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/kernel_lowering.cc @@ -60,7 +60,7 @@ namespace xla { namespace mlir_gpu { namespace { -using ::mlir::xla_lhlo::FusionOp; +using ::mlir::lmhlo::FusionOp; // Replaces a FusionOp by the operations contained in its region. struct FusionOpRemover @@ -463,14 +463,14 @@ Status LowerLHLOToGPU(mlir::ModuleOp module, LowerLHLOToGPUOptions options) { // Next, we can strip the outer fusion operation. pm.addPass(absl::make_unique()); // Remove unnecessary LHLO copies. - pm.addPass(::mlir::xla_lhlo::createLhloCopyRemovalPass()); + pm.addPass(::mlir::lmhlo::createLhloCopyRemovalPass()); // Transform LHLO operations to LinAlg. - pm.addPass(::mlir::xla_lhlo::createLegalizeLhloToLinalgPass()); + pm.addPass(::mlir::lmhlo::createLegalizeLhloToLinalgPass()); // Fuse linalg operations. - pm.addPass(::mlir::xla_lhlo::createLhloFuseLinalg(/*use_parallel_loops=*/true, - tiling_for_unrolling)); + pm.addPass(::mlir::lmhlo::createLhloFuseLinalg(/*use_parallel_loops=*/true, + tiling_for_unrolling)); // Legalize reduce operations directly to GPU dialect. 
- pm.addPass(::mlir::xla_lhlo::createLegalizeToGpuPass()); + pm.addPass(::mlir::lmhlo::createLegalizeToGpuPass()); // Transform the Linalg operations inside of the loop nest into parallel // loops. pm.addPass(::mlir::createConvertLinalgToParallelLoopsPass()); diff --git a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc index 3654271da53..2cf480bd4fc 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc +++ b/tensorflow/compiler/xla/service/mlir_gpu/lhlo_dialect_emitter.cc @@ -58,7 +58,7 @@ using ::xla::gpu::Thunk; using ::xla::gpu::ThunkEmitter; using ::xla::gpu::ThunkSequence; -namespace lhlo = ::mlir::xla_lhlo; +namespace lhlo = ::mlir::lmhlo; // TODO(b/137624192) Use tablegen for this. Status InsertMlirOp(HloOpcode opcode, OpBuilder func_builder, Location loc, diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo index 0927a6dc15d..ba29b0a17fd 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/abs.hlo @@ -6,5 +6,5 @@ ENTRY %Abs (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @abs(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.abs"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo index d8c20cfdab0..37c163eb83e 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add.hlo @@ -8,5 +8,5 @@ ENTRY %Add (x: f32[2,2], y: f32[2,2]) -> f32[2,2] { } // CHECK: func @add(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.add"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo index 70a5bb12f23..2603b925c76 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_multiply.hlo @@ -10,13 +10,13 @@ ENTRY %AddMultiply (x: f32[2,2], y: f32[2,2], z: f32[2,2]) -> f32[2,2] { } // CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]], %[[RESULT:.*]]: [[TYPE]]) -// CHECK: "xla_lhlo.fusion"() ( { +// CHECK: "lmhlo.fusion"() ( { // CHECK: %[[REF0:.*]] = tensor_load %[[ARG0]] : [[TYPE]] // CHECK: %[[REF1:.*]] = tensor_load %[[ARG1]] : [[TYPE]] // CHECK: %[[REF2:.*]] = tensor_load %[[ARG2]] : [[TYPE]] // CHECK: %[[ADD:.*]] = mhlo.add %[[REF1]], %[[REF2]] // CHECK: %[[MUL:.*]] = mhlo.multiply %[[ADD]], %[[REF0]] // CHECK: tensor_store %[[MUL]], %[[RESULT]] -// CHECK: "xla_lhlo.terminator"() +// CHECK: "lmhlo.terminator"() // CHECK-NEXT: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo index cf236c92ce1..a57f427cedc 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/add_reduce.hlo @@ -14,11 +14,11 @@ ENTRY %AddReduce (x: f32[100,10], c: f32[]) -> f32[100] { } // CHECK: func 
@reduce(%[[ARG:.*]]: [[ARGT:.*]], %[[CST:.*]]: memref, %[[RES:.*]]: [[REST:.*]]) { -// CHECK: "xla_lhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( { +// CHECK: "lmhlo.reduce"(%[[ARG]], %[[CST]], %[[RES]]) ( { // CHECK: ^bb0(%[[FARG0:.*]]: memref, %[[FARG1:.*]]: memref, %[[FRES:.*]]: memref): // CHECK: %[[LHS:.*]] = tensor_load %[[FARG0]] : memref // CHECK: %[[RHS:.*]] = tensor_load %[[FARG1]] : memref // CHECK: %[[RES:.*]] = mhlo.add %[[LHS]], %[[RHS]] : tensor // CHECK: tensor_store %[[RES]], %[[FRES]] : memref -// CHECK: "xla_lhlo.terminator"() : () -> () +// CHECK: "lmhlo.terminator"() : () -> () // CHECK-NEXT: }) {dimensions = dense<1> : tensor<1xi64>} : ([[ARGT]], memref, [[REST]]) -> () diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo index 9a2736c019a..366545c431f 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/broadcast.hlo @@ -7,7 +7,7 @@ ENTRY %Broadcast (x: f32[10]) -> f32[10, 5] { } // CHECK: func @broadcast(%[[IN:.*]]: [[IN_T:.*]], %[[OUT:.*]]: [[OUT_T:.*]]) { -// CHECK: "xla_lhlo.broadcast_in_dim"(%[[IN]], %[[OUT]]) +// CHECK: "lmhlo.broadcast_in_dim"(%[[IN]], %[[OUT]]) // CHECK: {broadcast_dimensions = dense<0> : tensor<1xi64>} // CHECK: : ([[IN_T]], [[OUT_T]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo index 71014e17db8..6bbddb61a74 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/broken_add.hlo @@ -7,4 +7,4 @@ ENTRY %Add (x: f32[2,2,2], y: f32[2,2,2]) -> f32[2,2,2] { ROOT %add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y) } -// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: xla_lhlo.add; failed for testing: std.return] +// CHECK: ERRORS FOUND: [%add = f32[2,2,2]{2,1,0} add(f32[2,2,2]{2,1,0} %x, f32[2,2,2]{2,1,0} %y): failed for testing: lmhlo.add; failed for testing: std.return] diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo index 26a4131617e..f45fa1a55e2 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/ceil.hlo @@ -6,5 +6,5 @@ ENTRY %Ceil (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @ceil(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.ceil"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo index 99662951456..2a34f494083 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/compare.hlo @@ -8,6 +8,6 @@ ENTRY %Compare (x: f32[2,2], y: f32[2,2]) -> pred[2,2] { } // CHECK: func @compare(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[PRED:.*]]: [[PRED_TYPE:.*]]) { -// CHECK: "xla_lhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]]) +// CHECK: "lmhlo.compare"(%[[ARG0]], %[[ARG1]], %[[PRED]]) // CHECK: {comparison_direction = "EQ"} : ([[TYPE]], [[TYPE]], [[PRED_TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo 
b/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo index 996ca0b2786..99a4872b228 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/complex.hlo @@ -8,5 +8,5 @@ ENTRY %Complex (real: f32[2,2]{0,1}, imag: f32[2,2]{0,1}) -> c64[2,2] { } // CHECK: func @complex(%[[REAL:.*]]: [[BUF_F32:.*]], %[[IMAG:.*]]: [[BUF_F32]], %[[OUT:.*]]: [[BUF_C64:.*]]) { -// CHECK: "xla_lhlo.complex"(%[[REAL]], %[[IMAG]], %[[OUT]]) : ([[BUF_F32]], [[BUF_F32]], [[BUF_C64]]) -> () +// CHECK: "lmhlo.complex"(%[[REAL]], %[[IMAG]], %[[OUT]]) : ([[BUF_F32]], [[BUF_F32]], [[BUF_C64]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo index 0b858842db7..06f29185aa1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/concatenate.hlo @@ -8,6 +8,6 @@ ENTRY %Concatenate (x: f32[2,3], y: f32[2,2]) -> f32[2,5] { } // CHECK: func @concatenate(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]], %[[RESULT:.*]]: [[RTYPE:.*]]) { -// CHECK: "xla_lhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) +// CHECK: "lmhlo.concatenate"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) // CHECK: {dimension = 1 : i64} : ([[TYPE0]], [[TYPE1]], [[RTYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo index 632a44a79e7..dd1f75b4192 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/const.hlo @@ -7,6 +7,6 @@ ENTRY %Const () -> s32[100] { } // CHECK: func @constant(%[[ARG0:.*]]: memref) -// CHECK: "xla_lhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor} +// CHECK: "lmhlo.constant"(%[[ARG0]]) {value = dense<10> : tensor} // CHECK: func @broadcast(%[[ARG1:.*]]: memref, %[[ARG2:.*]]: memref<100xi32>) -// CHECK: "xla_lhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} +// CHECK: "lmhlo.broadcast_in_dim"(%[[ARG1]], %[[ARG2]]) {broadcast_dimensions = dense<[]> : tensor<0xi64>} diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo index cc1acd03ad5..b4058da8019 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy.hlo @@ -7,4 +7,4 @@ ENTRY %Copy (x: f32[2,4]) -> f32[2,4] { } // CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, %[[RESULT:.*]]: memref<2x4xf32>) { -// CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> () +// CHECK: "lmhlo.copy"(%[[OPERAND]], %[[RESULT]]) : (memref<2x4xf32>, memref<2x4xf32>) -> () diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo index 7a9b994eae6..3a3dd22b338 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/copy_transpose.hlo @@ -9,5 +9,5 @@ ENTRY %CopyTranspose (x: f32[2,4]) -> f32[2,4]{0,1} { // CHECK: #[[MAP0:.*]] = affine_map<(d0, d1) -> (d0 * 2 + d1)> // CHECK: func @copy(%[[OPERAND:.*]]: memref<2x4xf32>, // CHECK-SAME: %[[RESULT:.*]]: memref<2x4xf32, #[[MAP0]]>) -// CHECK: "xla_lhlo.copy"(%[[OPERAND]], %[[RESULT]]) +// CHECK: "lmhlo.copy"(%[[OPERAND]], %[[RESULT]]) // CHECK-SAME: : 
(memref<2x4xf32>, memref<2x4xf32, #[[MAP0]]>) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo index 12c9c16d689..8a00a56206c 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/cos.hlo @@ -6,5 +6,5 @@ ENTRY %Cos (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @cosine(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.cosine"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.cosine"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo index 741ebe1118e..42cc605b2b6 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/exp.hlo @@ -7,6 +7,6 @@ ENTRY %Exp (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @exponential(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.exponential"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.exponential"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo index 2dbfb038081..f74cdef1473 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/fused_reduce.hlo @@ -21,7 +21,7 @@ ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] { } // CHECK: func @fusion(%[[ARG0:.*]]: [[TYPE:.*]], %[[RESULT:.*]]: [[RTYPE:.*]]) -// CHECK: "xla_lhlo.fusion"() ( { +// CHECK: "lmhlo.fusion"() ( { // CHECK: %[[REF0:.*]] = tensor_load %arg0 : [[TYPE]] // CHECK: %[[CT0:.*]] = mhlo.constant dense<0.000000e+00> // CHECK: %[[RED:.*]] = "mhlo.reduce"(%0, %1) ( { @@ -30,6 +30,6 @@ ENTRY %FusedReduce (x: f32[100,10]) -> f32[10] { // CHECK: "mhlo.return"(%[[ADD]]) // CHECK: }) // CHECK: tensor_store %[[RED]], %[[RESULT]] : [[RTYPE]] -// CHECK: "xla_lhlo.terminator"() +// CHECK: "lmhlo.terminator"() // CHECK-NEXT: }) diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo index 8dbd5dab178..470ae348740 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/gather.hlo @@ -11,7 +11,7 @@ ENTRY %Gather (x: f32[100,10], y: s64[4,6]) -> f32[4,6,10] { // CHECK: func @gather(%[[ARG0:.*]]: [[TYPE0:.*]], %[[ARG1:.*]]: [[TYPE1:.*]], // CHECK-SAME: %[[RESULT:.*]]: [[RTYPE:.*]]) { -// CHECK-NEXT: "xla_lhlo.gather"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) { +// CHECK-NEXT: "lmhlo.gather"(%[[ARG0]], %[[ARG1]], %[[RESULT]]) { // CHECK-SAME: collapsed_slice_dims = dense<0> : tensor<1xi64>, // CHECK-SAME: index_vector_dim = 2 : i64, // CHECK-SAME: offset_dims = dense<2> : tensor<1xi64>, diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo index 01d125fd866..50ff5571dbe 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/imag.hlo @@ -7,5 +7,5 @@ ENTRY %Imag (x: c64[2,2]{0,1}) -> f32[2,2] { } // CHECK: func @imag(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) { -// CHECK: "xla_lhlo.imag"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () +// CHECK: "lmhlo.imag"(%[[IN]], %[[OUT]]) : 
([[BUF_C64]], [[BUF_F32]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo index eb97667886f..1755e4b0157 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/iota.hlo @@ -6,6 +6,6 @@ HloModule Iota } // CHECK: func @iota(%[[OUT:.*]]: [[OUT_T:.*]]) { -// CHECK: "xla_lhlo.iota"(%[[OUT]]) +// CHECK: "lmhlo.iota"(%[[OUT]]) // CHECK: {iota_dimension = 0 : i64} : ([[OUT_T]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo index 3a19bc2f703..5f1156497b9 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/log.hlo @@ -7,5 +7,5 @@ ENTRY %Log (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @log(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.log"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo index 45804cf8edd..30557f13449 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/neg.hlo @@ -6,5 +6,5 @@ ENTRY %Neg (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @negate(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.negate"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.negate"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo index b1b02976a7d..559a4db4914 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/real.hlo @@ -7,5 +7,5 @@ ENTRY %Real (x: c64[2,2]{0,1}) -> f32[2,2] { } // CHECK: func @real(%[[IN:.*]]: [[BUF_C64:.*]], %[[OUT:.*]]: [[BUF_F32:.*]]) { -// CHECK: "xla_lhlo.real"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () +// CHECK: "lmhlo.real"(%[[IN]], %[[OUT]]) : ([[BUF_C64]], [[BUF_F32]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo index 7ceebd32c24..4c23a9854b1 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/reduce_window.hlo @@ -19,13 +19,13 @@ ENTRY %ReduceWindow (x: f32[128,64,112,112], y: f32[]) -> f32[128,64,56,56] { // CHECK: func @"reduce-window"( // CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[CST:%.*]]: memref, [[RES:%.*]]: [[REST:.*]]) { -// CHECK: "xla_lhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( { +// CHECK: "lmhlo.reduce_window"([[LHS:%.*]], [[RHS:%.*]], [[OUT:%.*]]) ( { // CHECK: ^bb0([[LHS:%.*]]: memref, [[RHS:%.*]]: memref, [[OUT:%.*]]: memref): // CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]] // CHECK: [[RHS_TENSOR:%.*]] = tensor_load [[RHS]] // CHECK: [[OUT_TENSOR:%.*]] = mhlo.maximum [[LHS_TENSOR]], [[RHS_TENSOR]] // CHECK: tensor_store [[OUT_TENSOR]], [[OUT]] -// CHECK: "xla_lhlo.terminator"() : () -> () +// CHECK: "lmhlo.terminator"() : () -> () // CHECK: }) { // CHECK-SAME: base_dilations = dense<1> : tensor<4xi64> // CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], 
[0, 1]]> diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo index 172e3224b77..6d3afb07f56 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rem.hlo @@ -7,5 +7,5 @@ ENTRY %Rem(x: f32[2,2], y: f32[2,2]) -> f32[2,2] { } // CHECK: func @remainder(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.remainder"(%[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[TYPE]], [[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo index 44167bba987..11d18e88061 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/rsqrt.hlo @@ -7,5 +7,5 @@ ENTRY %Rsqrt (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @rsqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.rsqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo index d900f56dcb2..bf25c69c524 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select.hlo @@ -9,6 +9,6 @@ ENTRY %Select (p: pred[2,2], x: f32[2,2], y: f32[2,2]) -> f32[2,2] { } // CHECK: func @select(%[[PRED:.*]]: [[PRED_TYPE:.*]], %[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]], %[[ARG2:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.select"(%[[PRED]], %[[ARG0]], %[[ARG1]], %[[ARG2]]) : ([[PRED_TYPE]], [[TYPE]], [[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo index 93d1eff3051..46d29856828 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/select_and_scatter.hlo @@ -30,7 +30,7 @@ ENTRY %SelectAndScatter (x: f32[128,64,112,112], // CHECK: func @"select-and-scatter"( // CHECK-SAME: [[ARG:%.*]]: [[ARGT:.*]], [[SRC:%.*]]: [[SRCT:.*]], [[CST:%.*]]: memref, [[RES:%.*]]: [[REST:.*]]) { -// CHECK: "xla_lhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( { +// CHECK: "lmhlo.select_and_scatter"([[ARG]], [[SRC]], [[CST]], [[RES]]) ( { // CHECK: ^bb0([[LHS:%.*]]: memref, [[RHS:%.*]]: memref, // CHECK-SAME: [[OUT:%.*]]: memref): // CHECK: [[LHS_TENSOR:%.*]] = tensor_load [[LHS]] @@ -38,7 +38,7 @@ ENTRY %SelectAndScatter (x: f32[128,64,112,112], // CHECK: [[OUT_TENSOR:%.*]] = "mhlo.compare" // CHECK-SAME: ([[LHS_TENSOR]], [[RHS_TENSOR]]) {comparison_direction = "GE"} // CHECK: tensor_store [[OUT_TENSOR]], [[OUT]] -// CHECK: xla_lhlo.terminator +// CHECK: lmhlo.terminator // CHECK: }, { // CHECK: ^bb0([[LHS_:%.*]]: memref, [[RHS_:%.*]]: memref, // CHECK-SAME: [[OUT_:%.*]]: memref): @@ -46,7 +46,7 @@ ENTRY %SelectAndScatter (x: f32[128,64,112,112], // CHECK: [[RHS_TENSOR_:%.*]] = tensor_load [[RHS_]] // CHECK: [[OUT_TENSOR_:%.*]] = mhlo.add [[LHS_TENSOR_]], [[RHS_TENSOR_]] // CHECK: tensor_store 
[[OUT_TENSOR_]], [[OUT_]] -// CHECK: xla_lhlo.terminator +// CHECK: lmhlo.terminator // CHECK: }) { // CHECK-SAME: padding = dense<{{\[}}[0, 0], [0, 0], [0, 1], [0, 1]]> // CHECK-SAME: window_dimensions = dense<[1, 1, 3, 3]> diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo index 0a7afa69bab..6acadb84e17 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sign.hlo @@ -6,5 +6,5 @@ ENTRY %Sign (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @sign(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.sign"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo index 54bf947350b..4e47229397d 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/sqrt.hlo @@ -7,6 +7,6 @@ ENTRY %Sqrt (x: f32[2,2]) -> f32[2,2] { } // CHECK: func @sqrt(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.sqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.sqrt"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } diff --git a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo index ff147c9041c..681c18aed29 100644 --- a/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo +++ b/tensorflow/compiler/xla/service/mlir_gpu/tests/tanh.hlo @@ -6,5 +6,5 @@ ENTRY %Tanh (val: f32[2,2]) -> f32[2,2] { } // CHECK: func @tanh(%[[ARG0:.*]]: [[TYPE:.*]], %[[ARG1:.*]]: [[TYPE]]) { -// CHECK: "xla_lhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () +// CHECK: "lmhlo.tanh"(%[[ARG0]], %[[ARG1]]) : ([[TYPE]], [[TYPE]]) -> () // CHECK: } From ab442b907600ac81bd771978b1ed9a466cf61dd8 Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 8 Jul 2020 10:12:48 -0700 Subject: [PATCH 54/88] Rename xla_chlo dialect into chlo Following on the plan of isolating the compiler/mlir/hlo directory. PiperOrigin-RevId: 320212018 Change-Id: Iffebf0107164ebc1d2af4fab9811681058183fea --- .../mlir-hlo/Dialect/mhlo/IR/chlo_ops.h | 10 +- .../mlir-hlo/Dialect/mhlo/IR/chlo_ops.td | 6 +- .../Dialect/mhlo/transforms/rewriters.h | 4 +- .../mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc | 8 +- .../Dialect/mhlo/IR/dialect_registration.cc | 3 +- .../mhlo/transforms/chlo_legalize_to_hlo.cc | 4 +- .../transforms/chlo_legalize_to_hlo_pass.cc | 8 +- .../tests/chlo_infer_shape_type_methods.mlir | 10 +- .../chlo_legalize_to_hlo_broadcasts.mlir | 48 ++--- .../mlir/tensorflow/tests/legalize_hlo.mlir | 98 ++++----- .../compiler/mlir/xla/tests/legalize-tf.mlir | 196 +++++++++--------- .../mlir/xla/transforms/legalize_tf.cc | 80 ++++--- 12 files changed, 236 insertions(+), 239 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h index 453d755a485..1fbf55ded83 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h @@ -28,18 +28,18 @@ limitations under the License. 
#include "mlir/Interfaces/SideEffectInterfaces.h" // from @llvm-project namespace mlir { -namespace xla_chlo { +namespace chlo { -class XlaHloClientDialect : public Dialect { +class HloClientDialect : public Dialect { public: - explicit XlaHloClientDialect(MLIRContext *context); - static StringRef getDialectNamespace() { return "xla_chlo"; } + explicit HloClientDialect(MLIRContext *context); + static StringRef getDialectNamespace() { return "chlo"; } }; #define GET_OP_CLASSES #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.h.inc" -} // namespace xla_chlo +} // namespace chlo } // namespace mlir #endif // TENSORFLOW_COMPILER_MLIR_HLO_INCLUDE_MLIR_HLO_DIALECT_MHLO_IR_CHLO_OPS_H_ diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td index 7f23de6d8bd..a31935076e1 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.td @@ -22,7 +22,7 @@ limitations under the License. // // The typical use of this dialect is for client libraries to be able to emit // less constrained ops and rely on the conversion framework to lower any -// xla_chlo ops to canonical mhlo ops. +// chlo ops to canonical mhlo ops. // // See: https://www.tensorflow.org/xla/operation_semantics @@ -35,8 +35,8 @@ include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir-hlo/Dialect/mhlo/IR/hlo_ops_base.td" def HLOClient_Dialect : Dialect { - let name = "xla_chlo"; - let cppNamespace = "xla_chlo"; + let name = "chlo"; + let cppNamespace = "chlo"; let summary = [{ XLA Client HLO Ops }]; diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h index fbff046ac40..95ce7f36a90 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h @@ -84,14 +84,14 @@ void PopulateLhloToLLVMConversionPatterns(const LowerToLLVMOptions &options, } // namespace lmhlo -namespace xla_chlo { +namespace chlo { // Populates a collection of conversion patterns for legalizing client-HLO to // HLO. void PopulateLegalizeChloToHloPatterns(MLIRContext *context, OwningRewritePatternList *patterns); -} // namespace xla_chlo +} // namespace chlo namespace xla { diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc index e43890979e5..330738c1aac 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/chlo_ops.cc @@ -23,7 +23,7 @@ limitations under the License. 
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" namespace mlir { -namespace xla_chlo { +namespace chlo { template static LogicalResult Verify(T op) { @@ -263,10 +263,10 @@ BROADCAST_BINARY_OP_DEFS(BroadcastXorOp); #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/chlo_ops.cc.inc" //===----------------------------------------------------------------------===// -// xla_chlo Dialect Constructor +// chlo Dialect Constructor //===----------------------------------------------------------------------===// -XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) +HloClientDialect::HloClientDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -274,5 +274,5 @@ XlaHloClientDialect::XlaHloClientDialect(MLIRContext* context) >(); } -} // namespace xla_chlo +} // namespace chlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc index 6d27d6015c2..3ec69c102ff 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -19,6 +19,5 @@ limitations under the License. // Static initialization for XLA dialect registration. static mlir::DialectRegistration mhlo_ops; -static mlir::DialectRegistration - xla_chlo_ops; +static mlir::DialectRegistration chlo_ops; static mlir::DialectRegistration lmhlo_ops; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc index aaeaca9c4d7..60f4010fed0 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo.cc @@ -25,7 +25,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/utils/broadcast_utils.h" namespace mlir { -namespace xla_chlo { +namespace chlo { namespace { @@ -235,5 +235,5 @@ void PopulateLegalizeChloToHloPatterns(MLIRContext *context, context, patterns); } -} // namespace xla_chlo +} // namespace chlo } // namespace mlir diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index b818e9a2792..c1ffd438d56 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -21,7 +21,7 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/transforms/rewriters.h" namespace mlir { -namespace xla_chlo { +namespace chlo { namespace { @@ -31,7 +31,7 @@ struct TestChloLegalizeToHloPass ConversionTarget conversionTarget(getContext()); OwningRewritePatternList conversionPatterns; - conversionTarget.addIllegalDialect(); + conversionTarget.addIllegalDialect(); // Consider the mhlo dialect legal for tests. conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. 
@@ -49,9 +49,9 @@ struct TestChloLegalizeToHloPass } // namespace -} // namespace xla_chlo +} // namespace chlo } // namespace mlir -static mlir::PassRegistration pass( +static mlir::PassRegistration pass( "test-xla-chlo-legalize-to-hlo", "Test pass for applying chlo -> hlo legalization patterns"); diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir index f71f58f1fe0..ab507809fdb 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_infer_shape_type_methods.mlir @@ -11,7 +11,7 @@ func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xinde // CHECK-DAG: %[[BCAST_S:.+]] = "shape.broadcast"(%[[ARG0_S]], %[[ARG1_S]]) // CHECK: %[[EXTENTS:.+]] = shape.to_extent_tensor %[[BCAST_S]] // CHECK: return %[[EXTENTS]] - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor %1 = "xla_test.reify_return_type_shapes"(%0) : (tensor) -> tensor<1xindex> return %1 : tensor<1xindex> } @@ -19,7 +19,7 @@ func @broadcast_add(%arg0: tensor, %arg1: tensor) -> tensor<1xinde // ----- // CHECK-LABEL: @complex_ranked_components func @complex_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor> { - %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = complex} %1 = "xla_test.get_return_type_components"(%0) : (tensor>) -> tensor> return %1 : tensor> @@ -28,7 +28,7 @@ func @complex_ranked_components(%arg0: tensor, %arg1: tensor) -> // ----- // CHECK-LABEL: @compare_ranked_components func @compare_ranked_components(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = i1} %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor return %0 : tensor @@ -37,7 +37,7 @@ func @compare_ranked_components(%arg0: tensor, %arg1: tensor) -> // ----- // CHECK-LABEL: @broadcast_add_ranked_components_r1 func @broadcast_add_ranked_components_r1(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor // CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1], element_type0 = f32} %1 = "xla_test.get_return_type_components"(%0) : (tensor) -> tensor return %1 : tensor @@ -46,7 +46,7 @@ func @broadcast_add_ranked_components_r1(%arg0: tensor, %arg1: tensor, %arg1: tensor) -> tensor { - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor // TODO: Overly broad shapes are being returned. Tighten the calculation // and update/extend these tests. 
// CHECK: "xla_test.return_type_components"(%0) {dims0 = [-1, -1], element_type0 = f32} diff --git a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir index 78617b7e3d0..cfd15e6a670 100644 --- a/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir +++ b/tensorflow/compiler/mlir/hlo/tests/chlo_legalize_to_hlo_broadcasts.mlir @@ -5,7 +5,7 @@ // CHECK-LABEL: @addWithoutBroadcast func @addWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.add %arg0, %arg1 - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -26,7 +26,7 @@ func @dynamicBroadcast(%arg0: tensor, %arg1: tensor) -> tensor - %0 = xla_chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor + %0 = chlo.broadcast_add %arg0, %arg1 : (tensor, tensor) -> tensor return %0 : tensor } @@ -47,7 +47,7 @@ func @dynamicBroadcastComplex(%arg0: tensor, %arg1: tensor) -> t // CHECK-NEXT: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK-NEXT: return %[[FINAL_RESULT]] : tensor> - %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> + %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor, tensor) -> tensor> return %0 : tensor> } @@ -68,7 +68,7 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK: shape.assuming_yield %[[RESULT]] // CHECK-NEXT: } // CHECK: return %[[FINAL_RESULT]] : tensor - %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor + %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor, tensor) -> tensor return %0 : tensor } @@ -77,7 +77,7 @@ func @dynamicBroadcastCompare(%arg0: tensor, %arg1: tensor) -> t // CHECK-LABEL: @dynamicNonScalarBroadcastDimensions func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // CHECK: mhlo.add - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -86,7 +86,7 @@ func @dynamicNonScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor< // CHECK-LABEL: @dynamicNonScalarByScalarBroadcastDimensions func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: tensor) -> tensor<1x4xf32> { // CHECK: mhlo.add - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1x4xf32>, tensor) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -95,7 +95,7 @@ func @dynamicNonScalarByScalarBroadcastDimensions(%arg0: tensor<1x4xf32>, %arg1: func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + 
%0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -104,7 +104,7 @@ func @dynamicNonScalarBroadcastDimensionsSizeMismatch(%arg0: tensor<1x4xf32>, %a func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: tensor<4xf32>) -> tensor<1x4xf32> { // expected-warning @+2 {{unsupported non prefix-padded dynamic rank broadcast_dimensions}} // expected-error @+1 {{failed to legalize operation}} - %0 = xla_chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> + %0 = chlo.broadcast_add %arg0, %arg1 {broadcast_dimensions = dense<2> : tensor<1xi64>} : (tensor<1x4xf32>, tensor<4xf32>) -> tensor<1x4xf32> return %0 : tensor<1x4xf32> } @@ -114,7 +114,7 @@ func @dynamicNonScalarBroadcastDimensionsMismatch(%arg0: tensor<1x4xf32>, %arg1: // CHECK-LABEL: @andWithoutBroadcast func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: mhlo.and %arg0, %arg1 - %0 = xla_chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + %0 = chlo.broadcast_and %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -122,7 +122,7 @@ func @andWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4x // CHECK-LABEL: @atan2WithoutBroadcast func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.atan2 %arg0, %arg1 - %0 = xla_chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_atan2 %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -130,7 +130,7 @@ func @atan2WithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso // CHECK-LABEL: @compareWithoutBroadcast func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xi1> { // CHECK: "mhlo.compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> - %0 = xla_chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> + %0 = chlo.broadcast_compare %arg0, %arg1 {comparison_direction = "EQ"} : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -138,7 +138,7 @@ func @compareWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // CHECK-LABEL: @complexWithoutBroadcast func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xcomplex> { // CHECK: "mhlo.complex"(%arg0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> - %0 = xla_chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> + %0 = chlo.broadcast_complex %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xcomplex> return %0 : tensor<4xcomplex> } @@ -146,7 +146,7 @@ func @complexWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // CHECK-LABEL: @divideWithoutBroadcast func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.divide %arg0, %arg1 - %0 = xla_chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_divide %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -154,7 +154,7 @@ func @divideWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: 
tensor<4xf32>) -> tens // CHECK-LABEL: @maximumWithoutBroadcast func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.maximum %arg0, %arg1 - %0 = xla_chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_maximum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -162,7 +162,7 @@ func @maximumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // CHECK-LABEL: @minimumWithoutBroadcast func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.minimum %arg0, %arg1 - %0 = xla_chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_minimum %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -170,7 +170,7 @@ func @minimumWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> ten // CHECK-LABEL: @multiplyWithoutBroadcast func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.multiply %arg0, %arg1 - %0 = xla_chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_multiply %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -178,7 +178,7 @@ func @multiplyWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> te // CHECK-LABEL: @orWithoutBroadcast func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: mhlo.or %arg0, %arg1 - %0 = xla_chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + %0 = chlo.broadcast_or %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } @@ -186,7 +186,7 @@ func @orWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi // CHECK-LABEL: @powerWithoutBroadcast func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.power %arg0, %arg1 - %0 = xla_chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_power %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -194,7 +194,7 @@ func @powerWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tenso // CHECK-LABEL: @remainderWithoutBroadcast func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.remainder %arg0, %arg1 - %0 = xla_chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_remainder %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -202,7 +202,7 @@ func @remainderWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> t // CHECK-LABEL: @shift_leftWithoutBroadcast func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.shift_left %arg0, %arg1 - %0 = xla_chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_shift_left %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -210,7 +210,7 @@ func @shift_leftWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> // CHECK-LABEL: @shift_right_arithmeticWithoutBroadcast func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, 
%arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.shift_right_arithmetic %arg0, %arg1 - %0 = xla_chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_shift_right_arithmetic %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -218,7 +218,7 @@ func @shift_right_arithmeticWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor // CHECK-LABEL: @shift_right_logicalWithoutBroadcast func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.shift_right_logical %arg0, %arg1 - %0 = xla_chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_shift_right_logical %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -226,7 +226,7 @@ func @shift_right_logicalWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4x // CHECK-LABEL: @subWithoutBroadcast func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor<4xf32> { // CHECK: mhlo.subtract %arg0, %arg1 - %0 = xla_chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %0 = chlo.broadcast_subtract %arg0, %arg1 : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> return %0 : tensor<4xf32> } @@ -234,6 +234,6 @@ func @subWithoutBroadcast(%arg0: tensor<4xf32>, %arg1: tensor<4xf32>) -> tensor< // CHECK-LABEL: @xorWithoutBroadcast func @xorWithoutBroadcast(%arg0: tensor<4xi1>, %arg1: tensor<4xi1>) -> tensor<4xi1> { // CHECK: mhlo.xor %arg0, %arg1 - %0 = xla_chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> + %0 = chlo.broadcast_xor %arg0, %arg1 : (tensor<4xi1>, tensor<4xi1>) -> tensor<4xi1> return %0 : tensor<4xi1> } diff --git a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir index 9197641502e..4f044cd5eff 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/legalize_hlo.mlir @@ -2,17 +2,17 @@ func @biasAdd_NHWC(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_NCHW(%arg0: tensor<1x32x10x32xi32>, %arg1: tensor<32xi32>) -> tensor<1x32x10x32xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<1x32x10x32xi32>, tensor<32xi32>) -> tensor<1x32x10x32xi32> return %0 : tensor<1x32x10x32xi32> } func @biasAdd_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -23,12 +23,12 @@ func @add(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_add(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) 
-> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } func @broadcast_multi_dim_add(%arg0: tensor<4x1x1xi32>, %arg1: tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> { - %0 = "xla_chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> + %0 = "chlo.broadcast_add"(%arg0, %arg1) {broadcast_dimensions = dense<[1, 2, 3]> : tensor<3xi64>} : (tensor<4x1x1xi32>, tensor<4x4x4x4xi32>) -> tensor<4x4x4x4xi32> return %0 : tensor<4x4x4x4xi32> } @@ -38,7 +38,7 @@ func @div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -48,7 +48,7 @@ func @shift_left(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @div_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor, tensor) -> tensor return %0 : tensor } @@ -68,7 +68,7 @@ func @mul(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_mul(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_multiply"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -78,7 +78,7 @@ func @real_div(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_real_div(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -88,7 +88,7 @@ func @sub(%arg0: tensor<2xi32>) -> tensor<2xi32> { } func @broadcast_sub(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi32> { - %0 = "xla_chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> + %0 = "chlo.broadcast_subtract"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi32> return %0 : tensor<1x2xi32> } @@ -98,7 +98,7 @@ func @shift_right(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @broadcast_shift_right(%arg0: tensor<4xi32>, %arg1: tensor<2x4xi32>) -> tensor<2x4xi32> { - %0 = "xla_chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : 
(tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> + %0 = "chlo.broadcast_shift_right_arithmetic"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<4xi32>, tensor<2x4xi32>) -> tensor<2x4xi32> return %0 : tensor<2x4xi32> } @@ -108,12 +108,12 @@ func @and(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @and_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @and_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -123,12 +123,12 @@ func @or(%arg0: tensor<2xi1>) -> tensor<2xi1> { } func @or_broadcast(%arg0: tensor<1xi1>, %arg1: tensor<1x2xi1>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi1>, tensor<1x2xi1>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @or_dynamic(%arg0: tensor, %arg1: tensor<1xi1>) -> tensor { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor + %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi1>) -> tensor return %0 : tensor } @@ -138,12 +138,12 @@ func @bitwise_or(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_or_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_or"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_or_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_or"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -153,12 +153,12 @@ func @bitwise_and(%arg0: tensor<4xi32>, %arg1: tensor<4xi32>) -> tensor<4xi32> { } func @bitwise_and_broadcast(%arg0: tensor<1xi8>, %arg1: tensor<1x4xi8>) -> tensor<1x4xi8> { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> + %0 = "chlo.broadcast_and"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<1xi8>, tensor<1x4xi8>) -> tensor<1x4xi8> return %0 : tensor<1x4xi8> } func @bitwise_and_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_and"(%arg0, %arg1) : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -174,19 +174,19 @@ func @pow_dynamic(%arg0: tensor) -> tensor { func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { %0 = mhlo.constant dense<0> : tensor<2x3xi32> - %1 = "xla_chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> 
tensor<2x3xi1> + %1 = "chlo.broadcast_compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> %2 = mhlo.constant dense<0> : tensor<3xi32> - %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> - %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> - %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2x3xi1>, tensor<3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<2x3xi32>) -> tensor<2x3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> %8 = mhlo.constant dense<1> : tensor<3xi32> %9 = mhlo.subtract %7, %8 : tensor<3xi32> - %10 = "xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<3xi32>) -> tensor<3xi32> - %13 = "xla_chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> + %13 = "chlo.broadcast_divide"(%11, %12) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> %14 = "mhlo.select"(%4, %5, %13) : (tensor<2x3xi1>, tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> return %14 : tensor<2x3xi32> } @@ -195,14 +195,14 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 %0 = mhlo.constant dense<0> : tensor<3xi32> %1 = "mhlo.compare"(%arg0, %0) {comparison_direction = "LT"} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi1> %2 = mhlo.constant dense<0> : tensor<2x3xi32> - %3 = "xla_chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> - %4 = "xla_chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> - %5 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %3 = "chlo.broadcast_compare"(%arg1, %2) {comparison_direction = "LT"} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi1> + %4 = "chlo.broadcast_compare"(%1, %3) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<3xi1>, tensor<2x3xi1>) -> tensor<2x3xi1> + %5 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %6 = "mhlo.abs"(%arg0) : (tensor<3xi32>) -> tensor<3xi32> %7 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %8 = mhlo.constant dense<1> : tensor<2x3xi32> %9 = mhlo.subtract %7, %8 : tensor<2x3xi32> - %10 = 
"xla_chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> + %10 = "chlo.broadcast_add"(%6, %9) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> %11 = "mhlo.negate"(%10) : (tensor<2x3xi32>) -> tensor<2x3xi32> %12 = "mhlo.abs"(%arg1) : (tensor<2x3xi32>) -> tensor<2x3xi32> %13 = mhlo.divide %11, %12 : tensor<2x3xi32> @@ -218,8 +218,8 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { } func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - %0 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> - %1 = "xla_chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %0 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> + %1 = "chlo.broadcast_divide"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> %2 = "mhlo.floor"(%1) : (tensor<2x3xf16>) -> tensor<2x3xf16> return %2 : tensor<2x3xf16> } @@ -230,22 +230,22 @@ func @equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @equal_dynamic(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } func @equal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @equal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "EQ"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -255,17 +255,17 @@ func @notequal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @notequal_broadcast(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<1xi32>, tensor<1x2xi32>) 
-> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_broadcast_no_incompatible_shapes_error(%arg0: tensor<2xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} : (tensor<2xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } func @notequal_incompatible_shape_broadcastable(%arg0: tensor, %arg1: tensor<1xi32>) -> tensor { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {comparison_direction = "NE"} : (tensor, tensor<1xi32>) -> tensor return %0 : tensor } @@ -275,7 +275,7 @@ func @greater(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -285,7 +285,7 @@ func @greater_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_greater_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "GE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -295,7 +295,7 @@ func @less(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LT"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -305,7 +305,7 @@ func @less_equal(%arg0: tensor<2xi32>) -> tensor<2xi1> { } func @broadcast_less_equal(%arg0: tensor<1xi32>, %arg1: tensor<1x2xi32>) -> tensor<1x2xi1> { - %0 = "xla_chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> + %0 = "chlo.broadcast_compare"(%arg0, %arg1) {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "LE"} : (tensor<1xi32>, tensor<1x2xi32>) -> tensor<1x2xi1> return %0 : tensor<1x2xi1> } @@ -326,35 +326,35 @@ func @const() -> tensor<2xi32> { func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor - %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> 
tensor<1xi32> return %1 : tensor<1xi32> } func @relu_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor - %1 = "xla_chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %1 = "chlo.broadcast_maximum"(%0, %arg0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %1 : tensor } func @relu6(%arg0: tensor<1xi32>) -> tensor<1xi32> { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor - %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> - %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<1xi32>, tensor) -> tensor<1xi32> return %3 : tensor<1xi32> } func @relu6_unranked(%arg0: tensor) -> tensor { %0 = mhlo.constant dense<0> : tensor %1 = mhlo.constant dense<6> : tensor - %2 = "xla_chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor - %3 = "xla_chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %2 = "chlo.broadcast_minimum"(%arg0, %1) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + %3 = "chlo.broadcast_maximum"(%2, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor return %3 : tensor } func @relu_grad(%arg0: tensor<4x8xf32>, %arg1: tensor) -> tensor<4x8xf32> { %0 = mhlo.constant dense<0.000000e+00> : tensor - %1 = "xla_chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor + %1 = "chlo.broadcast_compare"(%arg1, %0) {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "GT"} : (tensor, tensor) -> tensor %2 = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32> %3 = "mhlo.select"(%1, %arg0, %2) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> return %3 : tensor<4x8xf32> diff --git a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir index 62718203dac..0e3fcbb6364 100644 --- a/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir +++ b/tensorflow/compiler/mlir/xla/tests/legalize-tf.mlir @@ -54,7 +54,7 @@ func @fusedBatchNormV3_training(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8xf32> // CHECK: "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 1 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: %[[VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} : (tuple, tensor<8xf32>, tensor<8xf32>>) -> tensor<8xf32> // CHECK: mhlo.constant - // CHECK: xla_chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK: chlo.broadcast_multiply %[[VAR]], {{.*}} : (tensor<8xf32>, tensor) -> tensor<8xf32> return %0#0 : tensor<8x8x8x8xf32> } @@ -75,18 +75,18 @@ func @fusedBatchNormV3_training_exponentialAvgFactor(%arg0: tensor<8x8x8x8xf32>, // CHECK-DAG: %[[BATCH_VAR:.*]] = "mhlo.get_tuple_element"(%[[RESULT0]]) {index = 2 : i32} // CHECK: %[[FACTOR:.*]] = mhlo.constant dense<1.00195694> - // CHECK: %[[CORRECTED_VAR:.*]] = 
xla_chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] + // CHECK: %[[CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BATCH_VAR]], %[[FACTOR]] // CHECK-DAG: %[[ALPHA:.*]] = mhlo.constant dense<0.199999988> // CHECK-DAG: %[[BETA:.*]] = mhlo.constant dense<8.000000e-01> - // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg3 - // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] - // CHECK: %[[NEW_BATCH_MEAN:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] + // CHECK: %[[ALPHA_MUL_OLD_MEAN:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg3 + // CHECK: %[[BETA_MUL_BATCH_MEAN:.*]] = chlo.broadcast_multiply %[[BETA]], %[[BATCH_MEAN]] + // CHECK: %[[NEW_BATCH_MEAN:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_MEAN]], %[[BETA_MUL_BATCH_MEAN]] - // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = xla_chlo.broadcast_multiply %[[ALPHA]], %arg4 - // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = xla_chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] - // CHECK: %[[NEW_BATCH_VAR:.*]] = xla_chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] + // CHECK: %[[ALPHA_MUL_OLD_VAR:.*]] = chlo.broadcast_multiply %[[ALPHA]], %arg4 + // CHECK: %[[BETA_MUL_CORRECTED_VAR:.*]] = chlo.broadcast_multiply %[[BETA]], %[[CORRECTED_VAR]] + // CHECK: %[[NEW_BATCH_VAR:.*]] = chlo.broadcast_add %[[ALPHA_MUL_OLD_VAR]], %[[BETA_MUL_CORRECTED_VAR]] // CHECK: return %[[NEW_BATCH_MEAN]], %[[NEW_BATCH_VAR]], %[[BATCH_MEAN]], %[[BATCH_VAR]] return %0#1, %0#2, %0#3, %0#4 : tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> @@ -134,7 +134,7 @@ func @fusedBatchNormGrad_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<8x // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -193,7 +193,7 @@ func @fusedBatchNormGradV2_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -280,7 +280,7 @@ func @fusedBatchNormGradV3_noTraining(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor< // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // 
CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<3> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -367,7 +367,7 @@ func @fusedBatchNormGradV3_noTraining_NCHW(%arg0: tensor<8x8x8x8xf32>, %arg1: te // CHECK-NEXT: %[[act:.*]] = "mhlo.convert"(%arg1) : (tensor<8x8x8x8xf32>) -> tensor<8x8x8x8xf32> // CHECK-NEXT: %[[eps:.*]] = mhlo.constant dense<1.000000e-03> : tensor - // CHECK-NEXT: %[[add:.*]] = xla_chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> + // CHECK-NEXT: %[[add:.*]] = chlo.broadcast_add %arg4, %[[eps]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<8xf32>, tensor) -> tensor<8xf32> // CHECK-NEXT: %[[scr1:.*]] = "mhlo.rsqrt"(%[[add]]) : (tensor<8xf32>) -> tensor<8xf32> // CHECK: %[[bcast_arg3:.+]] = "mhlo.dynamic_broadcast_in_dim"(%arg3, {{.*}}) {broadcast_dimensions = dense<1> : tensor<1xi64>} : (tensor<8xf32>, tensor<4xindex>) -> tensor<8x8x8x8xf32> @@ -498,19 +498,19 @@ func @unary_einsum(%arg0: tensor<2x3xf32>) -> tensor<2x2xf32> { // CHECK-LABEL: func @floordiv_broadcast_i32 func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1) // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] 
{broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -520,19 +520,19 @@ func @floordiv_broadcast_i32(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> te // CHECK-LABEL: func @floordiv_reverse_broadcast_i32 func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1) // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] + // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -541,7 +541,7 @@ func @floordiv_reverse_broadcast_i32(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32 // CHECK-LABEL: func @floordiv_f32 func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { - // CHECK-NEXT: %[[DIV:.*]] = xla_chlo.broadcast_divide %arg0, %arg0 + // CHECK-NEXT: %[[DIV:.*]] = chlo.broadcast_divide %arg0, %arg0 // CHECK-NEXT: %[[FLOOR:.*]] = "mhlo.floor"(%[[DIV]]) // CHECK-NEXT: return %[[FLOOR]] : tensor<2xf32> %0 = "tf.FloorDiv"(%arg0, %arg0) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32> @@ -552,7 +552,7 @@ func @floordiv_f32(%arg0: tensor<2xf32>) -> tensor<2xf32> { func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-NEXT: mhlo.convert // CHECK-NEXT: mhlo.convert - // CHECK-NEXT: xla_chlo.broadcast_divide + // CHECK-NEXT: chlo.broadcast_divide // CHECK-NEXT: mhlo.floor // CHECK-NEXT: mhlo.convert // CHECK-NEXT: return @@ -562,7 +562,7 @@ 
func @floordiv_bf16(%arg0: tensor<2xbf16>) -> tensor<2xbf16> { // CHECK-LABEL: func @floordiv_f16_broadcast func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> tensor<2x3xf16> { - // CHECK-NEXT: xla_chlo.broadcast_divide + // CHECK-NEXT: chlo.broadcast_divide // CHECK-NEXT: mhlo.floor // CHECK-NEXT: return %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor<2x3xf16>, tensor<3xf16>) -> tensor<2x3xf16> @@ -572,19 +572,19 @@ func @floordiv_f16_broadcast(%arg0: tensor<2x3xf16>, %arg1: tensor<3xf16>) -> te // CHECK-LABEL: func @floordiv_dynamic func @floordiv_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { // CHECK-DAG: [[ZEROS1:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare %arg0, [[ZEROS1]] {comparison_direction = "LT"} // CHECK-DAG: [[ZEROS2:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} - // CHECK-DAG: [[DIV1:%.+]] = xla_chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZEROS2]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[CMP1]], [[CMP2]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "EQ"} + // CHECK-DAG: [[DIV1:%.+]] = chlo.broadcast_divide %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ABS1:%.+]] = "mhlo.abs"(%arg0) // CHECK-DAG: [[ABS2:%.+]] = "mhlo.abs"(%arg1) // CHECK-DAG: [[ONES:%.+]] = mhlo.constant dense<1> - // CHECK-DAG: [[SUB:%.+]] = xla_chlo.broadcast_subtract [[ABS2]], [[ONES]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[SUB:%.+]] = chlo.broadcast_subtract [[ABS2]], [[ONES]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[ABS1]], [[SUB]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[NEG:%.+]] = "mhlo.negate"([[ADD]]) // CHECK-DAG: [[ABS3:%.+]] = "mhlo.abs"(%arg1) - // CHECK-DAG: [[DIV2:%.+]] = xla_chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[DIV2:%.+]] = chlo.broadcast_divide [[NEG]], [[ABS3]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[CMP3]], [[DIV1]], [[DIV2]]) // CHECK: return [[SELECT]] %0 = "tf.FloorDiv"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -600,15 +600,15 @@ func @floordiv_unranked(%arg0: tensor<*xi32>, %arg1: tensor<*xi32>) -> tensor<*x // CHECK-LABEL: func @floormod_broadcast_numerator func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {broadcast_dimensions = dense<1> : 
tensor<1xi64>, comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32> @@ -617,15 +617,15 @@ func @floormod_broadcast_numerator(%arg0: tensor<3xi32>, %arg1: tensor<2x3xi32>) // CHECK-LABEL: func @floormod_broadcast_denominator func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32>) -> tensor<2x3xi32> { - // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor<2x3xi32>, tensor<3xi32>) -> tensor<2x3xi32> @@ -634,15 +634,15 @@ func @floormod_broadcast_denominator(%arg0: tensor<2x3xi32>, %arg1: tensor<3xi32 
// CHECK-LABEL: func @floormod_dynamic func @floormod_dynamic(%arg0: tensor, %arg1: tensor) -> tensor { - // CHECK-DAG: [[REM:%.+]] = xla_chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[REM:%.+]] = chlo.broadcast_remainder %arg0, %arg1 {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[ZL:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP1:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} + // CHECK-DAG: [[CMP1:%.+]] = chlo.broadcast_compare [[REM]], [[ZL]] {comparison_direction = "NE"} // CHECK-DAG: [[ZR:%.+]] = mhlo.constant dense<0> - // CHECK-DAG: [[CMP2:%.+]] = xla_chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} - // CHECK-DAG: [[CMP3:%.+]] = xla_chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} - // CHECK-DAG: [[CMP4:%.+]] = xla_chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} - // CHECK-DAG: [[AND:%.+]] = xla_chlo.broadcast_and [[CMP1]], [[CMP4]] - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} + // CHECK-DAG: [[CMP2:%.+]] = chlo.broadcast_compare %arg1, [[ZR]] {comparison_direction = "LT"} + // CHECK-DAG: [[CMP3:%.+]] = chlo.broadcast_compare [[REM]], [[ZR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>, comparison_direction = "LT"} + // CHECK-DAG: [[CMP4:%.+]] = chlo.broadcast_compare [[CMP2]], [[CMP3]] {broadcast_dimensions = dense<1> : tensor<1xi64>, comparison_direction = "NE"} + // CHECK-DAG: [[AND:%.+]] = chlo.broadcast_and [[CMP1]], [[CMP4]] + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add %arg1, [[REM]] {broadcast_dimensions = dense<1> : tensor<1xi64>} // CHECK-DAG: [[SELECT:%.+]] = "mhlo.select"([[AND]], [[ADD]], [[REM]]) // CHECK-NEXT: return [[SELECT]] %0 = "tf.FloorMod"(%arg0, %arg1) : (tensor, tensor) -> tensor @@ -979,10 +979,10 @@ func @matrix_band_part(%arg0: tensor<64x64xbf16>, %arg1: tensor, %arg2: ten // CHECK: %[[X:.*]] = "mhlo.iota"() {iota_dimension = 1 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<64x64xbf16> // CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<64x64xbf16> - // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> + // CHECK: %[[G:.*]] = chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<64x64xbf16>) -> tensor<64x64xi1> // CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> + // CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<64x64xbf16>, tensor) -> tensor<64x64xi1> // CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<64x64xi1> @@ -1000,10 +1000,10 @@ func @matrix_band_part_2(%arg0: tensor<12x24x48xbf16>, %arg1: tensor, %arg2 // CHECK: %[[Y:.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} : () -> tensor<24x48xbf16> // CHECK: %[[OFFSET:.*]] = mhlo.subtract %[[X]], %[[Y]] : tensor<24x48xbf16> - // CHECK: %[[G:.*]] = xla_chlo.broadcast_compare %[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> + // CHECK: %[[G:.*]] = chlo.broadcast_compare 
%[[F]], %[[OFFSET]] {comparison_direction = "LE"} : (tensor, tensor<24x48xbf16>) -> tensor<24x48xi1> // CHECK: %[[H:.*]] = "mhlo.convert"(%[[D]]) : (tensor) -> tensor - // CHECK: %[[I:.*]] = xla_chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> + // CHECK: %[[I:.*]] = chlo.broadcast_compare %[[OFFSET]], %[[H]] {comparison_direction = "LE"} : (tensor<24x48xbf16>, tensor) -> tensor<24x48xi1> // CHECK: %[[J:.*]] = mhlo.and %[[G]], %[[I]] : tensor<24x48xi1> // CHECK: %[[ZERO2:.*]] = mhlo.constant dense<0.000000e+00> : tensor<12x24x48xbf16> @@ -1315,7 +1315,7 @@ func @stateful_pcall_multi_in_out(%arg0: tensor, %arg1: tensor) -> (te // CHECK-LABEL: func @relu func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor<1xi32>) -> tensor<1xi32> %0 = "tf.Relu"(%arg0) : (tensor<1xi32>) -> tensor<1xi32> return %0: tensor<1xi32> } @@ -1323,7 +1323,7 @@ func @relu(%arg0: tensor<1xi32>) -> tensor<1xi32> { // CHECK-LABEL: func @relu_unranked func @relu_unranked(%arg0: tensor) -> tensor { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0> : tensor - // CHECK: xla_chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor + // CHECK: chlo.broadcast_maximum %[[ZERO]], %arg0 {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor, tensor) -> tensor %0 = "tf.Relu"(%arg0) : (tensor) -> tensor return %0: tensor } @@ -1351,7 +1351,7 @@ func @relu6_unranked(%arg0: tensor) -> tensor { func @relu_grad(%gradients: tensor<4x8xf32>, %features: tensor) -> tensor<4x8xf32> { // CHECK-DAG: %[[ZERO_SCALAR:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK-DAG: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor<4x8xf32> - // CHECK-DAG: %[[PRED:.*]] = xla_chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor, tensor) -> tensor + // CHECK-DAG: %[[PRED:.*]] = chlo.broadcast_compare %[[FEATURES]], %[[ZERO_SCALAR]] {comparison_direction = "GT"} : (tensor, tensor) -> tensor // CHECK-DAG: %[[RESULT:.*]] = "mhlo.select"(%[[PRED]], %[[GRADIENTS]], %[[ZERO]]) : (tensor, tensor<4x8xf32>, tensor<4x8xf32>) -> tensor<4x8xf32> // CHECK-DAG: return %[[RESULT]] : tensor<4x8xf32> %2 = "tf.ReluGrad"(%gradients, %features) : (tensor<4x8xf32>, tensor) -> tensor<4x8xf32> @@ -2473,10 +2473,10 @@ func @strided_slice_nonconstant_begin_end(%arg0: tensor, %arg1: tensor<32x1 // CHECK-DAG-SAME: start_indices = dense<0> : tensor<1xi64>, // CHECK-DAG-SAME: strides = dense<1> : tensor<1xi64>} : (tensor<1xi32>) -> tensor<1xi32> // CHECK-NEXT: %[[INDEX2:.*]] = "mhlo.reshape"(%[[INDEX]]) : (tensor<1xi32>) -> tensor - // CHECK-NEXT: %[[CMP:.*]] = xla_chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] + // CHECK-NEXT: %[[CMP:.*]] = chlo.broadcast_compare %[[INDEX2]], %[[ZERO]] // CHECK-DAG-SAME: {comparison_direction = "LT"} : (tensor, tensor) -> tensor // CHECK-NEXT: %[[DIM:.*]] = mhlo.constant dense<32> : tensor - // CHECK-NEXT: %[[WRAP:.*]] = xla_chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor + // CHECK-NEXT: %[[WRAP:.*]] = chlo.broadcast_add %[[DIM]], %[[INDEX2]] : (tensor, tensor) -> tensor // CHECK-NEXT: %[[INDEX3:.*]] = 
"mhlo.select"(%[[CMP]], %[[WRAP]], %[[INDEX2]]) : // CHECK-DAG-SAME: (tensor, tensor, tensor) -> tensor // CHECK-NEXT: %[[SLICED:.*]] = "mhlo.dynamic-slice" @@ -2605,7 +2605,7 @@ func @mean(%arg0: tensor<4x8xf16>) -> tensor<4x1xf16> { // CHECK: "mhlo.return"(%[[REDUCE_BODY_RESULT]]) : (tensor) -> () // CHECK: }) {dimensions = dense<1> : tensor<1xi64>} : (tensor<4x8xf32>, tensor) -> tensor<4xf32> // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<8.000000e+00> : tensor - // CHECK: %[[MEAN:.*]] = xla_chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> + // CHECK: %[[MEAN:.*]] = chlo.broadcast_divide %[[REDUCED]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<4xf32>, tensor) -> tensor<4xf32> // CHECK: %[[CAST_BACK:.*]] = "mhlo.convert"(%[[MEAN]]) : (tensor<4xf32>) -> tensor<4xf16> // CHECK: %[[RESULT:.*]] = "mhlo.reshape"(%[[CAST_BACK]]) : (tensor<4xf16>) -> tensor<4x1xf16> // CHECK: return %[[RESULT]] : tensor<4x1xf16> @@ -2909,8 +2909,8 @@ func @rng_std_normal(%arg0: tensor<3xi32>) -> tensor<12x?x64xf32> { func @range(%arg0: tensor, %arg1: tensor) -> tensor<5xf32> { %1 = "tf.Const"() {device = "", dtype = "tfdtype$DT_FLOAT", name = "range/limit", value = dense<5.000000e+00> : tensor} : () -> tensor // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota" - // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK: xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[DELTA]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK: chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} %3 = "tf.Range"(%arg0, %1, %arg1) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor<5xf32> return %3 : tensor<5xf32> } @@ -2929,8 +2929,8 @@ func @range_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor) // CHECK-DAG: [[IOTA:%.+]] = "mhlo.dynamic_iota"([[RESHAPE]]) {iota_dimension = 0 : i64} // CHECK-DAG: [[CONVERT3:%.+]] = "mhlo.convert"(%arg0) // CHECK-DAG: [[CONVERT4:%.+]] = "mhlo.convert"(%arg2) - // CHECK-DAG: [[MUL:%.+]] = xla_chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> tensor // CHECK: return [[ADD]] @@ -2951,8 +2951,8 @@ func @range_int_dynamic(%arg0: tensor, %arg1: tensor, %arg2: tensor : tensor<0xi64>} - // CHECK-DAG: [[ADD:%.+]] = xla_chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.+]] = chlo.broadcast_multiply [[IOTA]], [[CONVERT4]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[ADD:%.+]] = chlo.broadcast_add [[MUL]], [[CONVERT3]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} %2 = "tf.Range"(%arg0, %arg1, %arg2) {Tidx = "tfdtype$DT_FLOAT", device = "", name = "range"} : (tensor, tensor, tensor) -> 
tensor // CHECK: return [[ADD]] @@ -2966,12 +2966,12 @@ func @linspace_static(%arg0: tensor, %arg1: tensor) -> tensor<4xf32> { // CHECK-DAG: [[NUM_CAST:%.*]] = tensor_cast [[NUM]] // CHECK-DAG: [[NUM_F32:%.*]] = "mhlo.convert"([[NUM_CAST]]) // CHECK-DAG: [[ONE:%.*]] = mhlo.constant dense<1.000000e+00> - // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = xla_chlo.broadcast_subtract [[NUM_F32]], [[ONE]] - // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = xla_chlo.broadcast_subtract [[STOP]], [[START]] - // CHECK-DAG: [[STEP:%.*]] = xla_chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] + // CHECK-DAG: [[STEP_DENOMINATOR:%.*]] = chlo.broadcast_subtract [[NUM_F32]], [[ONE]] + // CHECK-DAG: [[STEP_NUMERATOR:%.*]] = chlo.broadcast_subtract [[STOP]], [[START]] + // CHECK-DAG: [[STEP:%.*]] = chlo.broadcast_divide [[STEP_NUMERATOR]], [[STEP_DENOMINATOR]] // CHECK-DAG: [[IOTA:%.*]] = "mhlo.iota"() {iota_dimension = 0 : i64} - // CHECK-DAG: [[MUL:%.*]] = xla_chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} - // CHECK-DAG: [[LINSPACE:%.*]] = xla_chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[MUL:%.*]] = chlo.broadcast_multiply [[IOTA]], [[STEP]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} + // CHECK-DAG: [[LINSPACE:%.*]] = chlo.broadcast_add [[MUL]], [[START]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} // CHECK: return [[LINSPACE]] %0 = "tf.Const"() {_output_shapes = ["tfshape$"], device = "", dtype = i32, value = dense<4> : tensor} : () -> tensor %1 = "tf.LinSpace"(%arg0, %arg1, %0) : (tensor, tensor, tensor) -> tensor<4xf32> @@ -3266,13 +3266,13 @@ func @size_ranked(%input: tensor<2x?x8xf32>) -> (tensor) { // CHECK: %[[CONST:.*]] = mhlo.constant dense<1> // CHECK: %[[DIM_0:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 0 - // CHECK: %[[MUL_0:.*]] = xla_chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] + // CHECK: %[[MUL_0:.*]] = chlo.broadcast_multiply %[[CONST]], %[[DIM_0]] // CHECK: %[[DIM_1:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 1 - // CHECK: %[[MUL_1:.*]] = xla_chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] + // CHECK: %[[MUL_1:.*]] = chlo.broadcast_multiply %[[MUL_0]], %[[DIM_1]] // CHECK: %[[DIM_2:.*]] = "mhlo.get_dimension_size"(%[[INPUT]]) // CHECK-SAME: dimension = 2 - // CHECK: %[[MUL_2:.*]] = xla_chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] + // CHECK: %[[MUL_2:.*]] = chlo.broadcast_multiply %[[MUL_1]], %[[DIM_2]] %size = "tf.Size"(%input) {T = "tfdtype$DT_FLOAT", out_type = "tfdtype$DT_INT32"} : (tensor<2x?x8xf32>) -> tensor // CHECK: return %[[MUL_2]] return %size : tensor @@ -3789,7 +3789,7 @@ func @random_shuffle_3D(%input: tensor<4x?x16xf32>) -> tensor<4x?x16xf32> { // CHECK: [[INDICES1:%.*]] = "mhlo.dynamic-update-slice"([[INDICES]], [[TGT_IDX]], [[IV]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[INDICES2:%.*]] = "mhlo.dynamic-update-slice"([[INDICES1]], [[SRC_IDX]], [[SWP]]) : (tensor<4xi32>, tensor<1xi32>, tensor) -> tensor<4xi32> // CHECK: [[ONE:%.*]] = mhlo.constant dense<1> : tensor - // CHECK: [[NEW_IV:%.*]] = xla_chlo.broadcast_add [[IV]], [[ONE]] + // CHECK: [[NEW_IV:%.*]] = chlo.broadcast_add [[IV]], [[ONE]] // CHECK: [[NEW_TUPLE:%.*]] = "mhlo.tuple"([[NEW_IV]], [[SWAPS]], [[INDICES2]]) // CHECK: "mhlo.return"([[NEW_TUPLE]]) // CHECK: }) : (tuple, tensor<4xi32>, tensor<4xi32>>) -> tuple, tensor<4xi32>, tensor<4xi32>> @@ -3822,7 +3822,7 @@ func @avgpool_valid_padding(%arg0: 
tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> // CHECK: "mhlo.return"([[ADD]]) // CHECK: }) {window_dimensions = dense<[1, 2, 2, 1]> : tensor<4xi64>, window_strides = dense<[1, 4, 4, 1]> : tensor<4xi64>} : (tensor<2x12x20x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[COUNT:%.+]] = mhlo.constant dense<4.000000e+00> : tensor - // CHECK: [[DIV:%.+]] = xla_chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> + // CHECK: [[DIV:%.+]] = chlo.broadcast_divide [[REDUCE]], [[COUNT]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<2x3x5x7xf32>, tensor) -> tensor<2x3x5x7xf32> // CHECK: [[CONV16:%.+]] = "mhlo.convert"([[DIV]]) : (tensor<2x3x5x7xf32>) -> tensor<2x3x5x7xf16> // CHECK: return [[CONV16]] %0 = "tf.AvgPool"(%arg0) {data_format = "NHWC", ksize = [1, 2, 2, 1], padding = "VALID", strides = [1, 4, 4, 1]} : (tensor<2x12x20x7xf16>) -> tensor<2x3x5x7xf16> @@ -3844,7 +3844,7 @@ func @avgpool_same_padding(%arg0: tensor<2x13x25x7xf32>) -> tensor<2x4x7x7xf32> // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xf32>) -> tensor<10x24x32x64xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] // CHECK_SAME: broadcast_dimensions = dense<[]> // CHECK_SAME: -> tensor<10x12x16x64xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) @@ -3876,7 +3876,7 @@ func @avgpool_grad_valid_padding(%grad: tensor<10x12x16x64xf32>) -> tensor<10x24 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x8x12x16x64xf32>) -> tensor<10x8x24x32x64xf32> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] {broadcast_dimensions = dense<[]> : tensor<0xi64>} : (tensor<10x8x12x16x64xf32>, tensor) -> tensor<10x8x12x16x64xf32> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) // CHECK-SAME: edge_padding_high = dense<[0, 0, 1, 1, 0]> // CHECK-SAME: edge_padding_low = dense<[0, 0, 1, 1, 0]> @@ -4059,7 +4059,7 @@ func @avgpool_3d_grad_ncdwh_format(%grad: tensor<2x9x8x4x7xf32>) -> tensor<2x9x8 // CHECK-SAME: %[[OUT_GRAD:.*]]: tensor<10x12x16x64xbf16>) -> tensor<10x24x32x64xbf16> { // CHECK: %[[ZERO:.*]] = mhlo.constant dense<0.000000e+00> : tensor // CHECK: %[[DIVISOR:.*]] = mhlo.constant dense<4.000000e+00> : tensor -// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = xla_chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] +// CHECK: %[[OUT_GRAD_DIVIDED:.*]] = chlo.broadcast_divide %[[OUT_GRAD]], %[[DIVISOR]] // CHECK-SAME: broadcast_dimensions = dense<[]> // CHECK-SAME: -> tensor<10x12x16x64xbf16> // CHECK: %[[REDUCE_WINDOW_INPUT:.*]] = "mhlo.pad"(%[[OUT_GRAD_DIVIDED]], %[[ZERO]]) @@ -4236,10 +4236,10 @@ func @softplus_f16(%arg0: tensor<8x16xf16>) -> tensor<8x16xf16> { // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.220700e-04> : tensor // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) // CHECK-DAG: [[TWO:%.*]] = 
mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) @@ -4256,10 +4256,10 @@ func @softplus_bf16(%arg0: tensor<8x16xbf16>) -> tensor<8x16xbf16> { // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<7.812500e-03> : tensor // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) @@ -4276,10 +4276,10 @@ func @softplus_f32(%arg0: tensor<8x16xf32>) -> tensor<8x16xf32> { // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<1.1920929E-7> : tensor // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], 
[[FEATURES]], [[ELSE_SELECT]]) @@ -4296,10 +4296,10 @@ func @softplus_f64(%arg0: tensor<8x16xf64>) -> tensor<8x16xf64> { // CHECK-DAG: [[EPSILON:%.*]] = mhlo.constant dense<2.2204460492503131E-16> : tensor // CHECK-DAG: [[EPSILON_LOG:%.*]] = "mhlo.log"([[EPSILON]]) // CHECK-DAG: [[TWO:%.*]] = mhlo.constant dense<2.000000e+00> : tensor - // CHECK: [[THRESHOLD:%.*]] = xla_chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] + // CHECK: [[THRESHOLD:%.*]] = chlo.broadcast_add [[EPSILON_LOG]], [[TWO]] // CHECK: [[NEG_THRESHOLD:%.*]] = "mhlo.negate"([[THRESHOLD]]) - // CHECK-DAG: [[COMPARE_GT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} - // CHECK-DAG: [[COMPARE_LT:%.*]] = xla_chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} + // CHECK-DAG: [[COMPARE_GT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[NEG_THRESHOLD]] {comparison_direction = "GT"} + // CHECK-DAG: [[COMPARE_LT:%.*]] = chlo.broadcast_compare [[FEATURES]], [[THRESHOLD]] {comparison_direction = "LT"} // CHECK-DAG: [[FEATURES_EXP_LOG:%.*]] = "mhlo.log_plus_one"([[FEATURES_EXP]]) // CHECK: [[ELSE_SELECT:%.*]] = "mhlo.select"([[COMPARE_LT]], [[FEATURES_EXP]], [[FEATURES_EXP_LOG]]) // CHECK: [[ENTRY_SELECT:%.*]] = "mhlo.select"([[COMPARE_GT]], [[FEATURES]], [[ELSE_SELECT]]) diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 119ba51c60d..2f856b79e47 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -715,7 +715,7 @@ static void CreateWhile32(Location loc, int num_iterations, auto one = builder->create(loc, builder->getI32IntegerAttr(1)); auto scalar_broadcast_dims = GetI64ElementsAttr({}, builder); - auto plus_one = builder->create( + auto plus_one = builder->create( loc, old_values[0], one, scalar_broadcast_dims); // Prepend with the updated loop induction variable. new_values.insert(new_values.begin(), plus_one); @@ -1483,7 +1483,7 @@ class ConvertFusedBatchNormGradBase RankedTensorType scalar_float = RankedTensorType::get({}, kernel_type); auto epsilon = rewriter.create( loc, DenseFPElementsAttr::get(scalar_float, {op.epsilon()})); - auto add_op = rewriter.create( + auto add_op = rewriter.create( loc, var, epsilon.getResult(), scalar_broadcast_dims); Value scratch1 = rewriter.create(loc, add_op); @@ -1601,7 +1601,7 @@ class ConvertFusedBatchNormV3Op auto factor_const_op = rewriter.create( op.getLoc(), rewriter.getFloatAttr(scale_element_type, factor)); - Value corrected_variance = rewriter.create( + Value corrected_variance = rewriter.create( op.getLoc(), batch_variance.getType(), batch_variance, factor_const_op, /*broadcast_dimensions=*/DenseIntElementsAttr()); @@ -1621,26 +1621,24 @@ class ConvertFusedBatchNormV3Op rewriter.getFloatAttr(mean_element_type, exponential_avg_factor)); // new_running_mean = alpha * old_mean + beta * batch_mean. 
- auto alpha_mul_old_mean = rewriter.create( + auto alpha_mul_old_mean = rewriter.create( op.getLoc(), op.mean().getType(), alpha, op.mean(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_mean = rewriter.create( + auto beta_mul_batch_mean = rewriter.create( op.getLoc(), batch_mean.getType(), beta, batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); - batch_mean = rewriter.create( + batch_mean = rewriter.create( op.getLoc(), alpha_mul_old_mean, beta_mul_batch_mean, /*broadcast_dimensions=*/DenseIntElementsAttr()); // new_running_variance = alpha * old_variance + beta * batch_variance. - auto alpha_mul_old_variance = rewriter.create( + auto alpha_mul_old_variance = rewriter.create( op.getLoc(), op.variance().getType(), alpha, op.variance(), /*broadcast_dimensions=*/DenseIntElementsAttr()); - auto beta_mul_batch_variance = - rewriter.create( - op.getLoc(), corrected_variance.getType(), beta, - corrected_variance, - /*broadcast_dimensions=*/DenseIntElementsAttr()); - corrected_variance = rewriter.create( + auto beta_mul_batch_variance = rewriter.create( + op.getLoc(), corrected_variance.getType(), beta, corrected_variance, + /*broadcast_dimensions=*/DenseIntElementsAttr()); + corrected_variance = rewriter.create( op.getLoc(), alpha_mul_old_variance, beta_mul_batch_variance, /*broadcast_dimensions=*/DenseIntElementsAttr()); } @@ -1810,7 +1808,7 @@ class ConvertAvgPoolOp : public OpRewritePattern { Value divisor = GetScalarConstOfType(sum_element_type, op.getLoc(), count, &rewriter); auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); - Value result = rewriter.create( + Value result = rewriter.create( op.getLoc(), result_type, reduce, divisor, scalar_broadcast_dims); // Convert back if we enlarged the element type's bitwidth. 
@@ -1914,7 +1912,7 @@ class ConvertAvgPoolGradOp : public OpRewritePattern { Value divisor = GetScalarConstOfType(element_type, loc, window_count, &rewriter); auto scalar_broadcast_dims = GetI64ElementsAttr({}, &rewriter); - out_grad_divided = rewriter.create( + out_grad_divided = rewriter.create( loc, out_grad_type, out_grad, divisor, scalar_broadcast_dims); } else { assert(op.padding() == "SAME"); @@ -2335,7 +2333,7 @@ class ConvertSizeOp : public OpRewritePattern { auto dim = rewriter.create( op.getLoc(), result_type, input, rewriter.getIntegerAttr(rewriter.getIntegerType(32), i)); - size = rewriter.create( + size = rewriter.create( op.getLoc(), size->getResult(0), dim.getResult(), /*DenseIntElementsAttr=*/DenseIntElementsAttr()); } @@ -3021,10 +3019,10 @@ class ConvertRangeOp : public OpRewritePattern { auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, op.delta(), xla::getBroadcastDimensionsAttr(&rewriter, iota, op.delta())); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -3101,10 +3099,10 @@ class ConvertDynamicRangeOp : public OpRewritePattern { auto iota = rewriter.create( op.getLoc(), result_type, reshape, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, delta_out_cast, xla::getBroadcastDimensionsAttr(&rewriter, iota, delta_cast)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, start_out_cast, xla::getBroadcastDimensionsAttr(&rewriter, scaled, start_out_cast)); return success(); @@ -3152,7 +3150,7 @@ class ConvertLinSpaceOp : public OpRewritePattern { int64_t num = (*num_attr.begin()).getSExtValue(); // Calculate the scaling that needs to be applied to the iota. - auto step_numerator = rewriter.create( + auto step_numerator = rewriter.create( op.getLoc(), op.start().getType(), op.stop(), op.start(), xla::getBroadcastDimensionsAttr(&rewriter, op.stop(), op.start())); Value step_denominator = rewriter.create( @@ -3160,11 +3158,11 @@ class ConvertLinSpaceOp : public OpRewritePattern { if (num > 1) { Value one = GetScalarConstOfType(result_type.getElementType(), op.getLoc(), 1, &rewriter); - step_denominator = rewriter.create( + step_denominator = rewriter.create( op.getLoc(), step_denominator.getType(), step_denominator, one, xla::getBroadcastDimensionsAttr(&rewriter, step_denominator, one)); } - auto step = rewriter.create( + auto step = rewriter.create( op.getLoc(), step_numerator.getType(), step_numerator, step_denominator, xla::getBroadcastDimensionsAttr(&rewriter, step_numerator, step_denominator)); @@ -3172,10 +3170,10 @@ class ConvertLinSpaceOp : public OpRewritePattern { // Scale the iota and add the offset. 
auto iota = rewriter.create(op.getLoc(), result_type, rewriter.getI64IntegerAttr(0)); - auto scaled = rewriter.create( + auto scaled = rewriter.create( op.getLoc(), result_type, iota, step, xla::getBroadcastDimensionsAttr(&rewriter, iota, step)); - rewriter.replaceOpWithNewOp( + rewriter.replaceOpWithNewOp( op, result_type, scaled, op.start(), xla::getBroadcastDimensionsAttr(&rewriter, scaled, op.start())); return success(); @@ -3251,7 +3249,7 @@ class GenericConvertReductionOp : public OpRewritePattern { auto divisor = GetScalarConstOfType(reduce_element_type, loc, divisor_count, &rewriter); auto broadcast_dims = GetI64ElementsAttr({}, &rewriter); - result = rewriter.create( + result = rewriter.create( loc, result, divisor.getResult(), broadcast_dims); } @@ -5008,11 +5006,11 @@ class ConvertQrOp : public OpRewritePattern { Value iota = builder->create( loc, RankedTensorType::get({m}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value gtk = builder->create( + Value gtk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("GT", builder->getContext())); gtk = builder->create(loc, gtk, x_type.getElementType()); - Value x_after_k = builder->create( + Value x_after_k = builder->create( loc, x, gtk, GetI64ElementsAttr({minor_dim}, builder)); Value x_after_k_sq = builder->create(loc, x_after_k, x_after_k); // sigma = np.dot(x[k+1:], x[k+1:]) @@ -5024,15 +5022,15 @@ class ConvertQrOp : public OpRewritePattern { Value mu = builder->create( loc, builder->create(loc, alpha_sq, sigma.getResult(0))); - Value sigma_is_zero = builder->create( + Value sigma_is_zero = builder->create( loc, sigma.getResult(0), zero, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); - Value alpha_is_negative = builder->create( + Value alpha_is_negative = builder->create( loc, alpha, zero, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); auto batch_size_one = builder->create( loc, alpha.getType(), one, GetI64ElementsAttr(batch_dims, builder)); - Value signed_mu = builder->create( + Value signed_mu = builder->create( loc, builder->create(loc, mu.getType(), alpha_is_negative, batch_size_one, @@ -5050,7 +5048,7 @@ class ConvertQrOp : public OpRewritePattern { divisor = builder->create(loc, divisor.getType(), sigma_is_zero, batch_size_one, divisor); - Value eqk = builder->create( + Value eqk = builder->create( loc, iota, k, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); eqk = builder->create(loc, eqk, x_type.getElementType()); @@ -5064,7 +5062,7 @@ class ConvertQrOp : public OpRewritePattern { // Form v as [0, 0, ..., 1] ++ x[k+1:] / divisor // If sigma is zero, x[k+1:] is zero, so use any non-zero divisor. // Note that the add performs a degenerate broadcast. 
- *v = builder->create( + *v = builder->create( loc, e_k, StaticBinaryBroadcast(loc, x_after_k, divisor, GetI64ElementsAttr(batch_dim_ids, builder), @@ -5154,12 +5152,12 @@ class ConvertQrOp : public OpRewritePattern { auto iota = builder->create( loc, RankedTensorType::get({m, 1}, builder->getIntegerType(32)), builder->getI64IntegerAttr(0)); - Value predecessor_mask = builder->create( + Value predecessor_mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("LT", builder->getContext())); predecessor_mask = builder->create(loc, predecessor_mask, a_type.getElementType()); - Value mask = builder->create( + Value mask = builder->create( loc, iota, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); mask = builder->create(loc, mask, a_type.getElementType()); @@ -5189,7 +5187,7 @@ class ConvertQrOp : public OpRewritePattern { loc, RankedTensorType::get(a_type.getShape(), builder->getIntegerType(32)), builder->getI64IntegerAttr(minor_dim + 1)); - Value xa_mask = builder->create( + Value xa_mask = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); a = builder->create(loc, a_type, xa_mask, new_x, a); @@ -5226,7 +5224,7 @@ class ConvertQrOp : public OpRewritePattern { loc, taus.getType(), taus_zeros, GetI64ElementsAttr(taus.getType().cast().getShape(), builder)); - Value taus_mask = builder->create( + Value taus_mask = builder->create( loc, iota_n, j, GetI64ElementsAttr({}, builder), StringAttr::get("EQ", builder->getContext())); auto taus_update = builder->create( @@ -5311,7 +5309,7 @@ class ConvertQrOp : public OpRewritePattern { loc, vs.getType(), zero, GetI64ElementsAttr(vs.getType().cast().getShape(), builder)); - auto compare = builder->create( + auto compare = builder->create( loc, iota_mn, j, GetI64ElementsAttr({}, builder), StringAttr::get("GE", builder->getContext())); auto y = builder->create(loc, vs.getType(), compare, zero, vs); @@ -5459,14 +5457,14 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, // Populate with CHLO->HLO lowerings to account for TF ops legalized to // CHLO first. if (legalize_chlo) { - xla_chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); + chlo::PopulateLegalizeChloToHloPatterns(context, &patterns); } ConversionTarget target(*context); if (legalize_chlo) { - target.addIllegalDialect(); + target.addIllegalDialect(); } else { - target.addLegalDialect(); + target.addLegalDialect(); } target.addLegalDialect(); target.addLegalDialect(); From 61d0aaf7ddc8221040dc41bd9b9795568ee8b220 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Wed, 8 Jul 2020 10:13:27 -0700 Subject: [PATCH 55/88] Update v1 only saved_model_experimental_test test with proper reason. PiperOrigin-RevId: 320212160 Change-Id: Ia72ea7c532d334c24ffb1d528db790245e389bea --- .../python/keras/saving/saved_model_experimental_test.py | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/keras/saving/saved_model_experimental_test.py b/tensorflow/python/keras/saving/saved_model_experimental_test.py index 281a58a1076..aa920c5642a 100644 --- a/tensorflow/python/keras/saving/saved_model_experimental_test.py +++ b/tensorflow/python/keras/saving/saved_model_experimental_test.py @@ -45,7 +45,8 @@ from tensorflow.python.saved_model import model_utils from tensorflow.python.training import training as training_module -@test_util.run_deprecated_v1 # Removed in v2. 
+@test_util.run_v1_only( + 'keras.experimental.load_from_saved_model is supported only in V1.') class TestModelSavingandLoading(parameterized.TestCase, test.TestCase): def _save_model_dir(self, dirname='saved_model'): @@ -279,7 +280,8 @@ def load_model(sess, path, mode): return inputs, outputs, meta_graph_def -@test_util.run_deprecated_v1 # Removed in v2. +@test_util.run_v1_only( + 'keras.experimental.export_saved_model is supported only in V1.') class TestModelSavedModelExport(test.TestCase, parameterized.TestCase): def _save_model_dir(self, dirname='saved_model'): From 2ebad1d7bea3d8b1db6738a3b85ba83a2bbd927a Mon Sep 17 00:00:00 2001 From: Frank Chen Date: Wed, 8 Jul 2020 10:14:19 -0700 Subject: [PATCH 56/88] Call TfTpu_Initialize() to initialize relevant bits of the TPU framework when loading the TPU library PiperOrigin-RevId: 320212388 Change-Id: I7a614c34407373af207f98db8f68cc342a176438 --- tensorflow/core/tpu/tpu_api_dlsym_initializer.cc | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc index 9080f789d4c..f3a857671c6 100644 --- a/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc +++ b/tensorflow/core/tpu/tpu_api_dlsym_initializer.cc @@ -53,9 +53,10 @@ Status InitializeTpuLibrary(void* library_handle) { // loaded. We do not want to register a TPU platform in XLA without the // supporting library providing the necessary APIs. if (s.ok()) { - // TODO(frankchn): Make initialization actually work - // Initialize TPU platform when the platform code is loaded from a library. - // InitializeApiFn()->TfTpu_InitializeFn(); + void (*initialize_fn)(); + initialize_fn = reinterpret_cast( + dlsym(library_handle, "TfTpu_Initialize")); + (*initialize_fn)(); RegisterTpuPlatform(); RegisterTpuSystemDevice(); From 17939ba17279795c7369e8cfa72d8f59cba3152e Mon Sep 17 00:00:00 2001 From: Mehdi Amini Date: Wed, 8 Jul 2020 10:19:13 -0700 Subject: [PATCH 57/88] Rename XlaHloDialect class into MhloDialect following the recent dialect namespace renaming PiperOrigin-RevId: 320213526 Change-Id: I31a3391837f300a58ea0f1c0991b5ba3573a4a31 --- .../hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h | 4 ++-- .../hlo/lib/Dialect/mhlo/IR/dialect_registration.cc | 2 +- .../compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc | 11 +++++------ .../mhlo/transforms/chlo_legalize_to_hlo_pass.cc | 2 +- .../Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc | 2 +- .../mhlo/transforms/materialize_broadcasts_pass.cc | 2 +- .../mhlo/transforms/xla_transform_unranked_hlo.cc | 2 +- .../mlir/tensorflow/utils/compile_mlir_util.cc | 2 +- .../compiler/mlir/tools/kernel_gen/cubin_creator.cc | 2 +- tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc | 2 +- .../compiler/mlir/xla/tests/mlir_hlo_builder_test.cc | 2 +- .../compiler/mlir/xla/transforms/legalize_tf.cc | 2 +- tensorflow/compiler/tf2xla/mlir_tf2xla.cc | 2 +- 13 files changed, 18 insertions(+), 19 deletions(-) diff --git a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h index aabe7bc9d3c..f368202ebad 100644 --- a/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h +++ b/tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h @@ -39,9 +39,9 @@ class OpBuilder; namespace mhlo { -class XlaHloDialect : public Dialect { +class MhloDialect : public Dialect { public: - explicit XlaHloDialect(MLIRContext *context); + explicit 
MhloDialect(MLIRContext *context); static StringRef getDialectNamespace() { return "mhlo"; } // Registered hook to materialize a constant operation from a given attribute diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc index 3ec69c102ff..d007e10b35b 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/dialect_registration.cc @@ -18,6 +18,6 @@ limitations under the License. #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/lhlo_ops.h" // Static initialization for XLA dialect registration. -static mlir::DialectRegistration mhlo_ops; +static mlir::DialectRegistration mhlo_ops; static mlir::DialectRegistration chlo_ops; static mlir::DialectRegistration lmhlo_ops; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc index dddc13c3065..b2786c3ab9d 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/IR/hlo_ops.cc @@ -62,9 +62,8 @@ namespace mlir { #include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_structs.cc.inc" namespace mhlo { -Operation* XlaHloDialect::materializeConstant(OpBuilder& builder, - Attribute value, Type type, - Location loc) { +Operation* MhloDialect::materializeConstant(OpBuilder& builder, Attribute value, + Type type, Location loc) { // HLO dialect constants only support ElementsAttr unlike standard dialect // constant which supports all attributes. if (value.isa()) @@ -2128,7 +2127,7 @@ struct HLOInlinerInterface : public DialectInlinerInterface { // mhlo Dialect Constructor //===----------------------------------------------------------------------===// -XlaHloDialect::XlaHloDialect(MLIRContext* context) +MhloDialect::MhloDialect(MLIRContext* context) : Dialect(getDialectNamespace(), context) { addOperations< #define GET_OP_LIST @@ -2140,7 +2139,7 @@ XlaHloDialect::XlaHloDialect(MLIRContext* context) // allowUnknownOperations(); } -Type XlaHloDialect::parseType(DialectAsmParser& parser) const { +Type MhloDialect::parseType(DialectAsmParser& parser) const { StringRef data_type; if (parser.parseKeyword(&data_type)) return Type(); @@ -2149,7 +2148,7 @@ Type XlaHloDialect::parseType(DialectAsmParser& parser) const { return nullptr; } -void XlaHloDialect::printType(Type type, DialectAsmPrinter& os) const { +void MhloDialect::printType(Type type, DialectAsmPrinter& os) const { if (type.isa()) { os << "token"; return; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc index c1ffd438d56..4aefe3a1f24 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/chlo_legalize_to_hlo_pass.cc @@ -33,7 +33,7 @@ struct TestChloLegalizeToHloPass conversionTarget.addIllegalDialect(); // Consider the mhlo dialect legal for tests. - conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. 
conversionTarget.addLegalDialect(); conversionTarget.addLegalDialect(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc index adf54828ba7..4ee45d56a8e 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/hlo_legalize_to_lhlo.cc @@ -413,7 +413,7 @@ struct HloLegalizeToLhlo target.addIllegalOp(); target.addLegalOp(); target.addLegalOp(); - target.addIllegalDialect(); + target.addIllegalDialect(); BufferAssignmentTypeConverter converter; auto isMemRefType = [](Type type) { return type.isa(); }; diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc index 5d89bf24eaa..5a64ec5ea85 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/materialize_broadcasts_pass.cc @@ -34,7 +34,7 @@ struct TestMaterializeBroadcastsPass OwningRewritePatternList conversionPatterns; // Consider the mhlo dialect legal for tests. - conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. conversionTarget.addLegalDialect(); diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc index 282f2681b49..53947855cc7 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/xla_transform_unranked_hlo.cc @@ -152,7 +152,7 @@ struct TransformUnrankedHloPass // Setup conversion target. MLIRContext &ctx = getContext(); ConversionTarget target(ctx); - target.addLegalDialect(); target.addLegalOp(); AddLegalOpOnRankedTensor(&target); diff --git a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc index 6b2469c0364..5e548da55f1 100644 --- a/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc +++ b/tensorflow/compiler/mlir/tensorflow/utils/compile_mlir_util.cc @@ -253,7 +253,7 @@ static void RegisterDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); - mlir::registerDialect(); + mlir::registerDialect(); return true; }(); (void)init_once; diff --git a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc index 348dec47ad2..1f511e27d9e 100644 --- a/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc +++ b/tensorflow/compiler/mlir/tools/kernel_gen/cubin_creator.cc @@ -88,7 +88,7 @@ struct MaterializeBroadcastsPass mlir::OwningRewritePatternList conversionPatterns; // Consider the mhlo dialect legal for tests. - conversionTarget.addLegalDialect(); + conversionTarget.addLegalDialect(); // The conversion uses helpers from the Standard dialect. 
conversionTarget.addLegalDialect(); diff --git a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc index 3b2f3091eb8..39ba58ebe48 100644 --- a/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc +++ b/tensorflow/compiler/mlir/xla/ir/mlir_hlo_builder.cc @@ -33,7 +33,7 @@ namespace xla { static std::string GetMlirOpName(HloOpcode opcode) { std::string op_name = HloOpcodeString(opcode); absl::c_replace(op_name, '-', '_'); - return mlir::mhlo::XlaHloDialect::getDialectNamespace().str() + "." + op_name; + return mlir::mhlo::MhloDialect::getDialectNamespace().str() + "." + op_name; } static std::string ToString(mlir::Type ty) { diff --git a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc index a75367ed309..1a3f0c16247 100644 --- a/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc +++ b/tensorflow/compiler/mlir/xla/tests/mlir_hlo_builder_test.cc @@ -48,7 +48,7 @@ class XlaBuilderTest : public ::testing::Test { xla_builder_(name_, builder_, module_->getLoc()) {} string SetupTest() { - mlir::registerDialect(); + mlir::registerDialect(); return ::testing::UnitTest::GetInstance()->current_test_info()->name(); } diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc index 2f856b79e47..6326aaf1868 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf.cc @@ -5466,7 +5466,7 @@ LogicalResult legalizeTF(Operation *op, bool allow_partial_conversion, } else { target.addLegalDialect(); } - target.addLegalDialect(); + target.addLegalDialect(); target.addLegalDialect(); target.addLegalDialect(); target.addLegalOp(); diff --git a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc index 9e6ecc9cac8..abaeb305104 100644 --- a/tensorflow/compiler/tf2xla/mlir_tf2xla.cc +++ b/tensorflow/compiler/tf2xla/mlir_tf2xla.cc @@ -95,7 +95,7 @@ static void RegisterDialects() { mlir::registerDialect(); mlir::registerDialect(); mlir::registerDialect(); - mlir::registerDialect(); + mlir::registerDialect(); mlir::registerDialect(); return true; }(); From 69ea5a77f97380f6796c20a41df4b00c847db01a Mon Sep 17 00:00:00 2001 From: Haoyu Zhang Date: Wed, 8 Jul 2020 10:20:52 -0700 Subject: [PATCH 58/88] Adjust error reporting and messages with remote targets. * When resetting the context (set_server_def), use LOG_AND_RETURN_IF_ERROR; * Log (instead of directly return) errors when creating remote contexts fails, to avoid the CHECK failures; * Explain when register function on remote workers could fail. 
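A minimal sketch of the status aggregation this change adopts in CreateRemoteContexts (see the diff below). The wrapper function and the header path are assumptions for illustration; StatusGroup, Update, and as_summary_status are the helpers used in the patch itself:

    #include <vector>
    #include "tensorflow/core/platform/status.h"

    // Fold every failed per-worker status into one summary status instead of
    // returning as soon as the first worker reports an error.
    tensorflow::Status SummarizeWorkerStatuses(
        const std::vector<tensorflow::Status>& statuses) {
      tensorflow::StatusGroup status_group;
      for (const tensorflow::Status& worker_status : statuses) {
        if (!worker_status.ok()) status_group.Update(worker_status);
      }
      // Returns OK when nothing was added; otherwise one combined summary.
      return status_group.as_summary_status();
    }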
PiperOrigin-RevId: 320213865 Change-Id: Ie6dac323e6265edc9079e22be3453a3b05949943 --- tensorflow/c/eager/c_api.cc | 33 ++++++++++++++----- .../core/common_runtime/eager/context.cc | 8 ++--- 2 files changed, 28 insertions(+), 13 deletions(-) diff --git a/tensorflow/c/eager/c_api.cc b/tensorflow/c/eager/c_api.cc index 4be3cdd7c2d..70acd710166 100644 --- a/tensorflow/c/eager/c_api.cc +++ b/tensorflow/c/eager/c_api.cc @@ -337,10 +337,13 @@ tensorflow::Status CreateRemoteContexts( }); } counter.Wait(); + tensorflow::StatusGroup sg; for (int i = 0; i < num_remote_workers; i++) { - TF_RETURN_IF_ERROR(statuses[i]); + if (TF_PREDICT_FALSE(!statuses[i].ok())) { + sg.Update(statuses[i]); + } } - return tensorflow::Status::OK(); + return sg.as_summary_status(); } tensorflow::Status UpdateRemoteContexts( @@ -611,10 +614,21 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( // Initialize remote eager workers. if (reset_context) { - LOG_AND_RETURN_IF_ERROR(CreateRemoteContexts( + const tensorflow::Status s = CreateRemoteContexts( ctx, remote_workers, context_id, context_view_id, keep_alive_secs, server_def, remote_eager_workers.get(), context->Executor().Async(), - context->LazyCopyFunctionRemoteInputs(), base_request)); + context->LazyCopyFunctionRemoteInputs(), base_request); + // NOTE: the remote tasks could fail after `GetAllRemoteDevices` and cause + // the CreateRemoteContexts to fail. We currently only log instead of + // directly returning the error, since returning here will cause the server + // object to be destroyed (which currently CHECK-fails). The client will + // see additional errors if ops are subsequently sent to the failed workers. + if (TF_PREDICT_FALSE(!s.ok())) { + LOG(ERROR) << "Error when creating contexts on remote targets: " + << s.error_message() + << "\nExecuting remote ops or functions on these remote " + "targets will fail."; + } } else { // The master's context_view_id will be incremented by one // the UpdateRemoteMaster call later. We want all new workers and @@ -644,15 +658,16 @@ tensorflow::Status UpdateTFE_ContextWithServerDef( grpc_server->worker_env()->rendezvous_mgr->Find(context_id); auto* device_mgr = grpc_server->worker_env()->device_mgr; std::shared_ptr worker_session; - TF_RETURN_IF_ERROR(grpc_server->worker_env()->session_mgr->CreateSession( - session_name, server_def, base_request.cluster_device_attributes(), - true)); - TF_RETURN_IF_ERROR( + LOG_AND_RETURN_IF_ERROR( + grpc_server->worker_env()->session_mgr->CreateSession( + session_name, server_def, base_request.cluster_device_attributes(), + true)); + LOG_AND_RETURN_IF_ERROR( grpc_server->worker_env()->session_mgr->WorkerSessionForSession( session_name, &worker_session)); // Initialize remote tensor communication based on worker session. 
- TF_RETURN_IF_ERROR(r->Initialize(worker_session.get())); + LOG_AND_RETURN_IF_ERROR(r->Initialize(worker_session.get())); tensorflow::DistributedFunctionLibraryRuntime* cluster_flr = tensorflow::eager::CreateClusterFLR(context_id, context, diff --git a/tensorflow/core/common_runtime/eager/context.cc b/tensorflow/core/common_runtime/eager/context.cc index 970dc40e38c..ae992f9d6f1 100644 --- a/tensorflow/core/common_runtime/eager/context.cc +++ b/tensorflow/core/common_runtime/eager/context.cc @@ -669,8 +669,8 @@ Status EagerContext::MaybeRegisterFunctionRemotely(const FunctionDef& fdef) { if (!status.ok()) { LOG(ERROR) << "Failed to register function remotely due to " << status.error_message() - << "\nThis shouldn't happen, please file a bug to " - "tensorflow team."; + << "\nThis could happen if the remote target has been " + "disconnected from the client."; } delete response; }); @@ -713,8 +713,8 @@ Status EagerContext::RegisterExistingFunctionsOnRemoteWorkers( if (!s.ok()) { LOG(ERROR) << "Failed to register function remotely due to " << s.error_message() - << "\nThis shouldn't happen, please file a bug to " - "tensorflow team."; + << "\nThis could happen if the remote target has been " + "disconnected from the client."; } }); } From b709aa20647994ba063fe5ab7b2cbf37cdb2bbbb Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 10:37:27 -0700 Subject: [PATCH 59/88] Fix typo in a message. PiperOrigin-RevId: 320217596 Change-Id: Iba6a9102ba341a87911f96b2fad44db88fc26678 --- tensorflow/compiler/xla/service/hlo_rematerialization.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/xla/service/hlo_rematerialization.cc b/tensorflow/compiler/xla/service/hlo_rematerialization.cc index bfc6769660a..2166ecdd890 100644 --- a/tensorflow/compiler/xla/service/hlo_rematerialization.cc +++ b/tensorflow/compiler/xla/service/hlo_rematerialization.cc @@ -746,7 +746,7 @@ Status MemoryUsageTracker::EndInstruction() { Buffer& buffer = buffers_.at(buffer_id); buffer.unfinished_user_count--; CHECK_GE(buffer.unfinished_user_count, 0) - << buffer.ToString() << " has negative unfinished use count."; + << buffer.ToString() << " has negative unfinished user count."; if (buffer.unfinished_user_count == 0) { // Buffer is now dead. VLOG(3) << " " << buffer.ToString() << " is now dead."; From 736b8b8b827e6bd1f4a0193bec3d0abfe6b697db Mon Sep 17 00:00:00 2001 From: Saurabh Saxena Date: Wed, 8 Jul 2020 10:39:38 -0700 Subject: [PATCH 60/88] Allow computing gradients of If/While in imported graphs. 
PiperOrigin-RevId: 320218171 Change-Id: I3f030716123b4e9ce3c2dc53b7c30112dc05a908 --- tensorflow/python/__init__.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/tensorflow/python/__init__.py b/tensorflow/python/__init__.py index d68119644c1..698a0d120c1 100644 --- a/tensorflow/python/__init__.py +++ b/tensorflow/python/__init__.py @@ -50,6 +50,7 @@ from tensorflow.python.layers import layers from tensorflow.python.module import module from tensorflow.python.ops import bincount_ops from tensorflow.python.ops import bitwise_ops as bitwise +from tensorflow.python.ops import cond_v2 from tensorflow.python.ops import gradient_checker_v2 from tensorflow.python.ops import image_ops as image from tensorflow.python.ops import manip_ops as manip @@ -58,6 +59,7 @@ from tensorflow.python.ops import nn from tensorflow.python.ops import ragged from tensorflow.python.ops import sets from tensorflow.python.ops import stateful_random_ops +from tensorflow.python.ops import while_v2 from tensorflow.python.ops.distributions import distributions from tensorflow.python.ops.linalg import linalg from tensorflow.python.ops.linalg.sparse import sparse From 1ca7b3ffda3ebb714ccaf2331b85ef993c5be6c8 Mon Sep 17 00:00:00 2001 From: Lucy Fox Date: Wed, 8 Jul 2020 10:42:17 -0700 Subject: [PATCH 61/88] Update whitelist to allowlist in comment. Noticed an instance of whitelist that didn't get changed (likely missed due to hyphen) so submitting a quick fix. PiperOrigin-RevId: 320218738 Change-Id: I0ff69dccac31ccd514800e32acf04c2a60caf9ea --- .../compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc index 916e7db33e3..d25b38d9ece 100644 --- a/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc +++ b/tensorflow/compiler/mlir/xla/transforms/legalize_tf_with_tf2xla.cc @@ -76,7 +76,7 @@ template using InlinedVector = tensorflow::gtl::InlinedVector; // non-absl ok static bool IsOpAllowlisted(Operation* op) { - // White-listed TensorFlow ops are known to have well behaved tf2xla kernels + // Allowlisted TensorFlow ops are known to have well behaved tf2xla kernels // building valid MLIR using MlirHloBuilder. // TODO(hinsu): Drop explicit allowlist when MLIR based bridge is enabled for // all tf2xla kernels. From fed98ed671a7f7fdbf8cec8e1ad1928302f9d3eb Mon Sep 17 00:00:00 2001 From: Alexander Belyaev Date: Wed, 8 Jul 2020 10:42:56 -0700 Subject: [PATCH 62/88] Add mlir_cpu_runner tests infra for CHLO->LHLO->LLVM lowering. 
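The extra conversion patterns wired up below are presumably needed because the bufferized LHLO code still carries affine and SCF (structured loop) constructs, which have to be rewritten into Standard-dialect form before the Standard-to-LLVM patterns can produce something mlir-cpu-runner can execute. A rough sketch of the wiring, using only the calls visible in the diff (converter, patterns, and the module m are assumed to come from the surrounding pass body):

    // Order of population does not matter; all patterns are applied together
    // against a conversion target for the LLVM dialect.
    populateStdToLLVMConversionPatterns(converter, patterns);
    PopulateLhloToLLVMConversionPatterns(
        LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns);
    // scf.for / scf.if -> Standard-dialect branches.
    mlir::populateLoopToStdConversionPatterns(patterns, &getContext());
    // affine loops and memory ops -> Standard/SCF ops.
    mlir::populateAffineToStdConversionPatterns(patterns, m.getContext());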
PiperOrigin-RevId: 320218897 Change-Id: I07e22cf1448cbe9ac4fa303fb817d40617ab633c --- tensorflow/compiler/mlir/hlo/BUILD | 2 ++ .../Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/tensorflow/compiler/mlir/hlo/BUILD b/tensorflow/compiler/mlir/hlo/BUILD index 918dc5833aa..c0a35615ce1 100644 --- a/tensorflow/compiler/mlir/hlo/BUILD +++ b/tensorflow/compiler/mlir/hlo/BUILD @@ -669,6 +669,8 @@ cc_library( ":lhlo_legalize_to_llvm", # build-cleaner: keep ":xla_materialize_broadcasts", # build-cleaner: keep ":xla_unfuse_batch_norm", # build-cleaner: keep + "@llvm-project//mlir:AffineToStandardTransforms", + "@llvm-project//mlir:CFGTransforms", "@llvm-project//mlir:IR", "@llvm-project//mlir:InferTypeOpInterface", "@llvm-project//mlir:LLVMDialect", diff --git a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc index 033bbaf210e..d6cda99a912 100644 --- a/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc +++ b/tensorflow/compiler/mlir/hlo/lib/Dialect/mhlo/transforms/lhlo_legalize_to_llvm_pass.cc @@ -13,6 +13,8 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "mlir/Conversion/AffineToStandard/AffineToStandard.h" // from @llvm-project +#include "mlir/Conversion/SCFToStandard/SCFToStandard.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVM.h" // from @llvm-project #include "mlir/Conversion/StandardToLLVM/ConvertStandardToLLVMPass.h" // from @llvm-project #include "mlir/Dialect/LLVMIR/LLVMDialect.h" // from @llvm-project @@ -38,6 +40,9 @@ class TestLhloToLLVMPass populateStdToLLVMConversionPatterns(converter, patterns); PopulateLhloToLLVMConversionPatterns( LowerToLLVMOptions::getDefaultOptions(), &converter, &patterns); + mlir::populateLoopToStdConversionPatterns(patterns, &getContext()); + + mlir::populateAffineToStdConversionPatterns(patterns, m.getContext()); ConversionTarget target(getContext()); target.addLegalDialect(); From a24f4580739d2ae26b8dad26672c543d4c757e21 Mon Sep 17 00:00:00 2001 From: Lucy Fox Date: Wed, 8 Jul 2020 10:55:25 -0700 Subject: [PATCH 63/88] Lower tf.IsNan op to an equivalent computation using tf.Equal op. 
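The new pattern materializes a scalar NaN constant and compares the input against it with tf.Equal. A small standalone check (illustration only, not part of the change) confirming that the 0x7FC00000 constant the updated lower_tf.mlir test below expects is the single-precision quiet-NaN bit pattern:

    #include <cassert>
    #include <cmath>
    #include <cstdint>
    #include <cstring>

    int main() {
      const uint32_t kQuietNanBits = 0x7FC00000u;  // Constant from the CHECK line.
      float nan_value;
      std::memcpy(&nan_value, &kQuietNanBits, sizeof(nan_value));
      assert(std::isnan(nan_value));  // 0x7FC00000 decodes to a quiet NaN.
      return 0;
    }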
PiperOrigin-RevId: 320221675 Change-Id: Ie97659a6b4150cce5af920f42c36c48e8c7b07d7 --- tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir | 9 +++++++++ .../compiler/mlir/tensorflow/transforms/lower_tf.td | 8 ++++++++ 2 files changed, 17 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir index c04f034ede6..3215055a249 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/lower_tf.mlir @@ -86,6 +86,15 @@ func @mul_no_nan(%arg0: tensor<2x3xf32>, %arg1: tensor<3xf32>) -> tensor<2x3xf32 return %0 : tensor<2x3xf32> } +// CHECK-LABEL: @is_nan +func @is_nan(%arg0: tensor<3x4xf32>) -> tensor<3x4xi1> { + // CHECK: %[[NAN:.*]] = "tf.Const"() {value = dense<0x7FC00000> : tensor} : () -> tensor + // CHECK: %[[RESULT:.*]] = "tf.Equal"(%arg0, %[[NAN]]) {incompatible_shape_error = true} : (tensor<3x4xf32>, tensor) -> tensor<3x4xi1> + %0 = "tf.IsNan"(%arg0) : (tensor<3x4xf32>) -> tensor<3x4xi1> + // CHECK: return %[[RESULT]] + return %0 : tensor<3x4xi1> +} + // CHECK-LABEL: func @fill // CHECK-SAME: (%[[ARG0:.*]]: tensor<*xi64>, %[[ARG1:.*]]: tensor<*xf32>) func @fill(%arg0: tensor<*xi64>, %arg1: tensor<*xf32>) -> tensor<*xf32> { diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td index acf9cd27b47..6b7d7178ab6 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.td @@ -154,6 +154,14 @@ foreach fromToBinPair = [[TF_DivNoNanOp, TF_DivOp], def LowerFillOp : Pat<(TF_FillOp $dims, $value), (TF_BroadcastToOp $value, $dims)>; +//===----------------------------------------------------------------------===// +// NaN op patterns. +//===----------------------------------------------------------------------===// + +def LowerIsNanOp : Pat<(TF_IsNanOp $x), + (TF_EqualOp $x, (TF_ConstOp:$nan (GetScalarNanOfType $x)), + /*incompatible_shape_error*/ConstBoolAttrTrue)>; + //===----------------------------------------------------------------------===// // L2Loss op patterns. //===----------------------------------------------------------------------===// From 21ed77b7b5d7257c503cc7c5cb75cb6f2f872303 Mon Sep 17 00:00:00 2001 From: Ken Franko Date: Wed, 8 Jul 2020 10:58:30 -0700 Subject: [PATCH 64/88] Correctly set the experimental_io_device when restoring variable from a checkpoint. PiperOrigin-RevId: 320222381 Change-Id: I30187c7777ab8056e48004ef5e4ae747edc32227 --- tensorflow/python/training/tracking/base.py | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/python/training/tracking/base.py b/tensorflow/python/training/tracking/base.py index 9337adbf88a..47fbdddd4d9 100644 --- a/tensorflow/python/training/tracking/base.py +++ b/tensorflow/python/training/tracking/base.py @@ -293,9 +293,10 @@ class CheckpointPosition(object): checkpoint_key = serialized_tensor.checkpoint_key dtype = self._checkpoint.dtype_map[checkpoint_key] base_type = dtype.base_dtype + io_device = self._checkpoint.options.experimental_io_device or "cpu:0" with ops.init_scope(): - with ops.device("/cpu:0"): - # Run the restore itself on the CPU. + with ops.device(io_device): + # Run the restore itself on the io_device(CPU or specified). 
value, = io_ops.restore_v2( prefix=self._checkpoint.save_path_tensor, tensor_names=[checkpoint_key], From 73166ba1f6cb531583b9db4ee0d68b30291d0e85 Mon Sep 17 00:00:00 2001 From: Anna R Date: Wed, 8 Jul 2020 11:07:28 -0700 Subject: [PATCH 65/88] Update TF_AllocateOutput to just call ctx->allocate_output now that TF_Tensor just wraps a Tensor. PiperOrigin-RevId: 320224601 Change-Id: Ib0d3c3f9cd1e5a74d103dedcc17118e6a7c92817 --- tensorflow/c/kernels.cc | 25 ++++++++++++++++--------- 1 file changed, 16 insertions(+), 9 deletions(-) diff --git a/tensorflow/c/kernels.cc b/tensorflow/c/kernels.cc index a0ed0d9f245..3021a38e888 100644 --- a/tensorflow/c/kernels.cc +++ b/tensorflow/c/kernels.cc @@ -248,15 +248,22 @@ TF_Tensor* TF_AllocateOutput(TF_OpKernelContext* context, int index, size_t len, TF_Status* status) { TF_SetStatus(status, TF_OK, ""); auto* cc_ctx = reinterpret_cast<::tensorflow::OpKernelContext*>(context); - tensorflow::AllocatorAttributes attr = cc_ctx->output_alloc_attr(index); - auto* allocator = cc_ctx->get_allocator(attr); - void* data = tensorflow::allocate_tensor("TF_AllocateOutput", len, allocator); - TF_Tensor* result = TF_NewTensor(dtype, dims, num_dims, data, len, - tensorflow::deallocate_buffer, allocator); - TF_SetOutput(context, index, result, status); - if (TF_GetCode(status) != TF_OK) { - TF_DeleteTensor(result); + + static_assert(sizeof(int64_t) == sizeof(tensorflow::int64), + "64-bit int types should match in size"); + tensorflow::gtl::ArraySlice dimarray( + reinterpret_cast(dims), num_dims); + tensorflow::Tensor* tensor; + tensorflow::Status s = cc_ctx->allocate_output( + index, tensorflow::TensorShape(dimarray), &tensor); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); return nullptr; } - return result; + TF_Tensor* tf_tensor = TF_TensorFromTensor(*tensor, &s); + if (!s.ok()) { + ::tensorflow::Set_TF_Status_from_Status(status, s); + return nullptr; + } + return tf_tensor; } From e0d896af5c7f1c287c36dbefab955a8c599bb41a Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 11:13:59 -0700 Subject: [PATCH 66/88] Expose Interpreter from InterpreterWrapper. PiperOrigin-RevId: 320226019 Change-Id: I20a0fec86ccdfab9785e9940ad2845d7261bcf34 --- tensorflow/lite/python/interpreter.py | 21 +++++++++++++++++++ .../interpreter_wrapper/interpreter_wrapper.h | 7 +++++++ 2 files changed, 28 insertions(+) diff --git a/tensorflow/lite/python/interpreter.py b/tensorflow/lite/python/interpreter.py index f4a9d96da3f..35e05d8c8c9 100644 --- a/tensorflow/lite/python/interpreter.py +++ b/tensorflow/lite/python/interpreter.py @@ -526,6 +526,27 @@ class Interpreter(object): def reset_all_variables(self): return self._interpreter.ResetVariableTensors() + # Experimental and subject to change. + def _native_interpreter(self): + """Returns the underlying InterpreterWrapper object. + + This allows users to extend tflite.Interpreter's functionality in custom cpp + function. For example, + at cpp level: + void SomeNewFeature(InterpreterWrapper* wrapper) { + // Get access to tflite::Interpreter + auto* interpreter = wrapper->interpreter(); + // ... + } + at python level: + def some_new_feature(interpreter): + _cpp_to_py_wrapper.SomeNewFeature(interpreter._native_interpreter()) + + Note: This approach is fragile. Users must guarantee the C++ extension build + is consistent with the tflite.Interpreter's underlying C++ build. 
+ """ + return self._interpreter + class InterpreterWithCustomOps(Interpreter): """Interpreter interface for TensorFlow Lite Models that accepts custom ops. diff --git a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h index b799a3067f6..5580eaa0f4b 100644 --- a/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h +++ b/tensorflow/lite/python/interpreter_wrapper/interpreter_wrapper.h @@ -92,6 +92,13 @@ class InterpreterWrapper { // Adds a delegate to the interpreter. PyObject* ModifyGraphWithDelegate(TfLiteDelegate* delegate); + // Experimental and subject to change. + // + // Returns a pointer to the underlying interpreter. + tflite_api_dispatcher::Interpreter* interpreter() { + return interpreter_.get(); + } + private: // Helper function to construct an `InterpreterWrapper` object. // It only returns InterpreterWrapper if it can construct an `Interpreter`. From eae9ba7557779e957212d6a9d215e3ea23fa9db6 Mon Sep 17 00:00:00 2001 From: Advait Jain Date: Wed, 8 Jul 2020 11:23:26 -0700 Subject: [PATCH 67/88] Make ParseOpData a stub function that returns an error for TFLM. PiperOrigin-RevId: 320227976 Change-Id: Iaa13812b519eb8926e00b7660f36f69eeae112ef --- .../lite/core/api/flatbuffer_conversions.cc | 1408 +++++++++-------- 1 file changed, 721 insertions(+), 687 deletions(-) diff --git a/tensorflow/lite/core/api/flatbuffer_conversions.cc b/tensorflow/lite/core/api/flatbuffer_conversions.cc index 4d243f9a033..2977b709bb3 100644 --- a/tensorflow/lite/core/api/flatbuffer_conversions.cc +++ b/tensorflow/lite/core/api/flatbuffer_conversions.cc @@ -131,6 +131,699 @@ TfLitePadding ConvertPadding(Padding padding) { return kTfLitePaddingUnknown; } +TfLiteStatus ParseOpDataTfLite(const Operator* op, BuiltinOperator op_type, + ErrorReporter* error_reporter, + BuiltinDataAllocator* allocator, + void** builtin_data) { + auto parseLSHProjectionType = [](LSHProjectionType type) { + switch (type) { + case LSHProjectionType_SPARSE: + return kTfLiteLshProjectionSparse; + case LSHProjectionType_DENSE: + return kTfLiteLshProjectionDense; + default: + return kTfLiteLshProjectionUnknown; + } + }; + auto parseCombinerType = [](CombinerType type) { + switch (type) { + case CombinerType_MEAN: + return kTfLiteCombinerTypeMean; + case CombinerType_SQRTN: + return kTfLiteCombinerTypeSqrtn; + case CombinerType_SUM: + default: + return kTfLiteCombinerTypeSum; + } + }; + + SafeBuiltinDataAllocator safe_allocator(allocator); + *builtin_data = nullptr; + switch (op_type) { + case BuiltinOperator_ABS: { + return ParseAbs(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_ADD: { + return ParseAdd(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_ARG_MAX: { + return ParseArgMax(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_ARG_MIN: { + return ParseArgMin(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_AVERAGE_POOL_2D: { + return ParsePool(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_CEIL: { + return ParseCeil(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_CONCATENATION: { + return ParseConcatenation(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_CONV_2D: { + return ParseConv2D(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_DEPTHWISE_CONV_2D: { + return ParseDepthwiseConv2D(op, error_reporter, allocator, builtin_data); + 
} + + case BuiltinOperator_DEQUANTIZE: { + return ParseDequantize(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_FLOOR: { + return ParseFloor(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_FULLY_CONNECTED: { + return ParseFullyConnected(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_GREATER: { + return ParseGreater(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_GREATER_EQUAL: { + return ParseGreaterEqual(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_HARD_SWISH: { + return ParseHardSwish(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_L2_NORMALIZATION: { + return ParseL2Normalization(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_L2_POOL_2D: { + return ParsePool(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_LESS: { + return ParseLess(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_LESS_EQUAL: { + return ParseLessEqual(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_LOG: { + return ParseLog(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_LOGICAL_AND: { + return ParseLogicalAnd(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_LOGICAL_NOT: { + return ParseLogicalNot(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_LOGICAL_OR: { + return ParseLogicalOr(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_LOGISTIC: { + return ParseLogistic(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_MAXIMUM: { + return ParseMaximum(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_MAX_POOL_2D: { + return ParsePool(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_MEAN: { + return ParseReducer(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_MINIMUM: { + return ParseMinimum(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_MUL: { + return ParseMul(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_NEG: { + return ParseNeg(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_NOT_EQUAL: { + return ParseNotEqual(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_PACK: { + return ParsePack(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_PAD: { + return ParsePad(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_PADV2: { + return ParsePadV2(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_PRELU: { + return ParsePrelu(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_QUANTIZE: { + return ParseQuantize(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_ANY: { + return ParseReducer(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_MAX: { + return ParseReducer(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_MIN: { + return ParseReducer(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_REDUCE_PROD: { + return ParseReducer(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_RELU: { + return ParseRelu(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_RELU6: { + return ParseRelu6(op, 
error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_RESHAPE: { + return ParseReshape(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: { + return ParseResizeNearestNeighbor(op, error_reporter, allocator, + builtin_data); + } + + case BuiltinOperator_ROUND: { + return ParseRound(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_RSQRT: { + return ParseRsqrt(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SIN: { + return ParseSin(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SOFTMAX: { + return ParseSoftmax(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SPLIT: { + return ParseSplit(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SQRT: { + return ParseSqrt(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SQUARE: { + return ParseSquare(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_STRIDED_SLICE: { + return ParseStridedSlice(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SUB: { + return ParseSub(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SUM: { + return ParseReducer(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_SVDF: { + return ParseSvdf(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_TANH: { + return ParseTanh(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_UNPACK: { + return ParseUnpack(op, error_reporter, allocator, builtin_data); + } + + case BuiltinOperator_CAST: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = op->builtin_options_as_CastOptions()) { + TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->in_data_type(), + ¶ms->in_data_type, + error_reporter)); + TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->out_data_type(), + ¶ms->out_data_type, + error_reporter)); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_LSH_PROJECTION: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* lshParams = + op->builtin_options_as_LSHProjectionOptions()) { + params->type = parseLSHProjectionType(lshParams->type()); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* sequence_rnn_params = + op->builtin_options_as_SequenceRNNOptions()) { + params->activation = + ConvertActivation(sequence_rnn_params->fused_activation_function()); + params->time_major = sequence_rnn_params->time_major(); + params->asymmetric_quantize_inputs = + sequence_rnn_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { + auto params = + safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* bidi_sequence_rnn_params = + op->builtin_options_as_BidirectionalSequenceRNNOptions()) { + params->activation = ConvertActivation( + bidi_sequence_rnn_params->fused_activation_function()); + params->time_major = bidi_sequence_rnn_params->time_major(); + params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); + 
params->asymmetric_quantize_inputs = + bidi_sequence_rnn_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_RNN: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { + params->activation = + ConvertActivation(rnn_params->fused_activation_function()); + params->asymmetric_quantize_inputs = + rnn_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { + auto params = + safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* embedding_params = + op->builtin_options_as_EmbeddingLookupSparseOptions()) { + params->combiner = parseCombinerType(embedding_params->combiner()); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + + case BuiltinOperator_HASHTABLE_LOOKUP: + // no-op. + return kTfLiteOk; + case BuiltinOperator_DIV: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = op->builtin_options_as_DivOptions()) { + params->activation = + ConvertActivation(schema_params->fused_activation_function()); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = + op->builtin_options_as_LocalResponseNormalizationOptions()) { + params->radius = schema_params->radius(); + params->bias = schema_params->bias(); + params->alpha = schema_params->alpha(); + params->beta = schema_params->beta(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_LSTM: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) { + params->activation = + ConvertActivation(lstm_params->fused_activation_function()); + params->cell_clip = lstm_params->cell_clip(); + params->proj_clip = lstm_params->proj_clip(); + switch (lstm_params->kernel_type()) { + case LSTMKernelType_FULL: + params->kernel_type = kTfLiteLSTMFullKernel; + break; + case LSTMKernelType_BASIC: + params->kernel_type = kTfLiteLSTMBasicKernel; + break; + default: + TF_LITE_REPORT_ERROR(error_reporter, + "Unhandled LSTM kernel type: %d", + lstm_params->kernel_type()); + return kTfLiteError; + } + params->asymmetric_quantize_inputs = + lstm_params->asymmetric_quantize_inputs(); + } else { + TF_LITE_REPORT_ERROR(error_reporter, + "No valid LSTM builtin options exist"); + return kTfLiteError; + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { + auto params = + safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* seq_lstm_params = + op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { + params->activation = + ConvertActivation(seq_lstm_params->fused_activation_function()); + params->cell_clip = seq_lstm_params->cell_clip(); + params->proj_clip = seq_lstm_params->proj_clip(); + params->time_major = seq_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + seq_lstm_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case 
BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { + auto params = + safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* bidi_lstm_params = + op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { + params->activation = + ConvertActivation(bidi_lstm_params->fused_activation_function()); + params->cell_clip = bidi_lstm_params->cell_clip(); + params->proj_clip = bidi_lstm_params->proj_clip(); + params->merge_outputs = bidi_lstm_params->merge_outputs(); + params->time_major = bidi_lstm_params->time_major(); + params->asymmetric_quantize_inputs = + bidi_lstm_params->asymmetric_quantize_inputs(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_RESIZE_BILINEAR: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = + op->builtin_options_as_ResizeBilinearOptions()) { + params->align_corners = schema_params->align_corners(); + params->half_pixel_centers = schema_params->half_pixel_centers(); + } else { + // Some older models did not populate the ResizeBilinearOptions field in + // the flatbuffer, so ensure it's set to a sensible default. + params->align_corners = false; + params->half_pixel_centers = false; + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_SKIP_GRAM: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* skip_gram_params = + op->builtin_options_as_SkipGramOptions()) { + params->ngram_size = skip_gram_params->ngram_size(); + params->max_skip_size = skip_gram_params->max_skip_size(); + params->include_all_ngrams = skip_gram_params->include_all_ngrams(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_SPACE_TO_DEPTH: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = + op->builtin_options_as_SpaceToDepthOptions()) { + params->block_size = schema_params->block_size(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_DEPTH_TO_SPACE: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = + op->builtin_options_as_DepthToSpaceOptions()) { + params->block_size = schema_params->block_size(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_GATHER: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + params->axis = 0; + if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { + params->axis = gather_params->axis(); + } + + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_SPLIT_V: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = op->builtin_options_as_SplitVOptions()) { + params->num_splits = schema_params->num_splits(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_SQUEEZE: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = op->builtin_options_as_SqueezeOptions()) { + const auto* squeeze_dims = schema_params->squeeze_dims(); + TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( + sizeof(params->squeeze_dims), squeeze_dims, params->squeeze_dims, + error_reporter, 
"squeeze")); + params->num_squeeze_dims = squeeze_dims->size(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_TRANSPOSE_CONV: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* transpose_conv_params = + op->builtin_options_as_TransposeConvOptions()) { + params->padding = ConvertPadding(transpose_conv_params->padding()); + params->stride_width = transpose_conv_params->stride_w(); + params->stride_height = transpose_conv_params->stride_h(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_SPARSE_TO_DENSE: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* sparse_to_dense_params = + op->builtin_options_as_SparseToDenseOptions()) { + params->validate_indices = sparse_to_dense_params->validate_indices(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_SHAPE: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = op->builtin_options_as_ShapeOptions()) { + TF_LITE_ENSURE_STATUS(ConvertTensorType( + schema_params->out_type(), ¶ms->out_type, error_reporter)); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_DELEGATE: { + // TODO(ycling): Revisit when supporting saving delegated models. + TF_LITE_REPORT_ERROR(error_reporter, + "DELEGATE op shouldn't exist in model."); + return kTfLiteError; + } + case BuiltinOperator_FAKE_QUANT: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = + op->builtin_options_as_FakeQuantOptions()) { + params->min = schema_params->min(); + params->max = schema_params->max(); + params->num_bits = schema_params->num_bits(); + params->narrow_range = schema_params->narrow_range(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_ONE_HOT: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* schema_params = op->builtin_options_as_OneHotOptions()) { + params->axis = schema_params->axis(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_LEAKY_RELU: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* leaky_relu_params = + op->builtin_options_as_LeakyReluOptions()) { + params->alpha = leaky_relu_params->alpha(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_MIRROR_PAD: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + const auto* mirror_pad_params = op->builtin_options_as_MirrorPadOptions(); + if (mirror_pad_params != nullptr) { + params->mode = + mirror_pad_params->mode() == tflite::MirrorPadMode_REFLECT + ? TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect + : TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingSymmetric; + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_UNIQUE: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + const auto* unique_params = op->builtin_options_as_UniqueOptions(); + if (unique_params != nullptr) { + params->index_out_type = + unique_params->idx_out_type() == tflite::TensorType_INT64 + ? 
TfLiteType::kTfLiteInt64 + : TfLiteType::kTfLiteInt32; + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_REVERSE_SEQUENCE: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* reverse_seq_params = + op->builtin_options_as_ReverseSequenceOptions()) { + params->seq_dim = reverse_seq_params->seq_dim(); + params->batch_dim = reverse_seq_params->batch_dim(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_IF: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* if_params = op->builtin_options_as_IfOptions()) { + params->then_subgraph_index = if_params->then_subgraph_index(); + params->else_subgraph_index = if_params->else_subgraph_index(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_WHILE: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* while_params = op->builtin_options_as_WhileOptions()) { + params->cond_subgraph_index = while_params->cond_subgraph_index(); + params->body_subgraph_index = while_params->body_subgraph_index(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + case BuiltinOperator_BATCH_MATMUL: { + auto params = safe_allocator.Allocate(); + TF_LITE_ENSURE(error_reporter, params != nullptr); + if (const auto* bmm_params = + op->builtin_options_as_BatchMatMulOptions()) { + params->adj_x = bmm_params->adj_x(); + params->adj_y = bmm_params->adj_y(); + } + *builtin_data = params.release(); + return kTfLiteOk; + } + // Below are the ops with no builtin_data structure. + case BuiltinOperator_BATCH_TO_SPACE_ND: + // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are + // ok for now, since there is no call implementation either. 
+ case BuiltinOperator_CALL: + case BuiltinOperator_CONCAT_EMBEDDINGS: + case BuiltinOperator_COS: + case BuiltinOperator_CUSTOM: + case BuiltinOperator_ELU: + case BuiltinOperator_EMBEDDING_LOOKUP: + case BuiltinOperator_EQUAL: + case BuiltinOperator_EXP: + case BuiltinOperator_EXPAND_DIMS: + case BuiltinOperator_LOG_SOFTMAX: + case BuiltinOperator_MATRIX_DIAG: + case BuiltinOperator_MATRIX_SET_DIAG: + case BuiltinOperator_RELU_N1_TO_1: + case BuiltinOperator_SELECT: + case BuiltinOperator_SELECT_V2: + case BuiltinOperator_SLICE: + case BuiltinOperator_SPACE_TO_BATCH_ND: + case BuiltinOperator_TILE: + case BuiltinOperator_TOPK_V2: + case BuiltinOperator_TRANSPOSE: + case BuiltinOperator_POW: + case BuiltinOperator_FLOOR_DIV: + case BuiltinOperator_ZEROS_LIKE: + case BuiltinOperator_FILL: + case BuiltinOperator_FLOOR_MOD: + case BuiltinOperator_RANGE: + case BuiltinOperator_SQUARED_DIFFERENCE: + case BuiltinOperator_REVERSE_V2: + case BuiltinOperator_ADD_N: + case BuiltinOperator_GATHER_ND: + case BuiltinOperator_WHERE: + case BuiltinOperator_RANK: + case BuiltinOperator_NON_MAX_SUPPRESSION_V4: + case BuiltinOperator_NON_MAX_SUPPRESSION_V5: + case BuiltinOperator_SCATTER_ND: + case BuiltinOperator_DENSIFY: + case BuiltinOperator_SEGMENT_SUM: + return kTfLiteOk; + } + return kTfLiteError; +} // NOLINT[readability/fn_size] + } // namespace TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type, @@ -1007,693 +1700,34 @@ TfLiteStatus ParseUnpack(const Operator* op, ErrorReporter* error_reporter, TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type, ErrorReporter* error_reporter, BuiltinDataAllocator* allocator, void** builtin_data) { - auto parseLSHProjectionType = [](LSHProjectionType type) { - switch (type) { - case LSHProjectionType_SPARSE: - return kTfLiteLshProjectionSparse; - case LSHProjectionType_DENSE: - return kTfLiteLshProjectionDense; - default: - return kTfLiteLshProjectionUnknown; - } - }; - auto parseCombinerType = [](CombinerType type) { - switch (type) { - case CombinerType_MEAN: - return kTfLiteCombinerTypeMean; - case CombinerType_SQRTN: - return kTfLiteCombinerTypeSqrtn; - case CombinerType_SUM: - default: - return kTfLiteCombinerTypeSum; - } - }; - - SafeBuiltinDataAllocator safe_allocator(allocator); - *builtin_data = nullptr; - switch (op_type) { - case BuiltinOperator_ABS: { - return ParseAbs(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_ADD: { - return ParseAdd(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_ARG_MAX: { - return ParseArgMax(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_ARG_MIN: { - return ParseArgMin(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_AVERAGE_POOL_2D: { - return ParsePool(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_CEIL: { - return ParseCeil(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_CONCATENATION: { - return ParseConcatenation(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_CONV_2D: { - return ParseConv2D(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_DEPTHWISE_CONV_2D: { - return ParseDepthwiseConv2D(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_DEQUANTIZE: { - return ParseDequantize(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_FLOOR: { - return ParseFloor(op, error_reporter, allocator, builtin_data); - } - - case 
BuiltinOperator_FULLY_CONNECTED: { - return ParseFullyConnected(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_GREATER: { - return ParseGreater(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_GREATER_EQUAL: { - return ParseGreaterEqual(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_HARD_SWISH: { - return ParseHardSwish(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_L2_NORMALIZATION: { - return ParseL2Normalization(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_L2_POOL_2D: { - return ParsePool(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_LESS: { - return ParseLess(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_LESS_EQUAL: { - return ParseLessEqual(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_LOG: { - return ParseLog(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_LOGICAL_AND: { - return ParseLogicalAnd(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_LOGICAL_NOT: { - return ParseLogicalNot(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_LOGICAL_OR: { - return ParseLogicalOr(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_LOGISTIC: { - return ParseLogistic(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_MAXIMUM: { - return ParseMaximum(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_MAX_POOL_2D: { - return ParsePool(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_MEAN: { - return ParseReducer(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_MINIMUM: { - return ParseMinimum(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_MUL: { - return ParseMul(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_NEG: { - return ParseNeg(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_NOT_EQUAL: { - return ParseNotEqual(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_PACK: { - return ParsePack(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_PAD: { - return ParsePad(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_PADV2: { - return ParsePadV2(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_PRELU: { - return ParsePrelu(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_QUANTIZE: { - return ParseQuantize(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_ANY: { - return ParseReducer(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_MAX: { - return ParseReducer(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_MIN: { - return ParseReducer(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_REDUCE_PROD: { - return ParseReducer(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_RELU: { - return ParseRelu(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_RELU6: { - return ParseRelu6(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_RESHAPE: { - return ParseReshape(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_RESIZE_NEAREST_NEIGHBOR: { - return 
ParseResizeNearestNeighbor(op, error_reporter, allocator, - builtin_data); - } - - case BuiltinOperator_ROUND: { - return ParseRound(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_RSQRT: { - return ParseRsqrt(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SIN: { - return ParseSin(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SOFTMAX: { - return ParseSoftmax(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SPLIT: { - return ParseSplit(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SQRT: { - return ParseSqrt(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SQUARE: { - return ParseSquare(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_STRIDED_SLICE: { - return ParseStridedSlice(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SUB: { - return ParseSub(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SUM: { - return ParseReducer(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_SVDF: { - return ParseSvdf(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_TANH: { - return ParseTanh(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_UNPACK: { - return ParseUnpack(op, error_reporter, allocator, builtin_data); - } - - case BuiltinOperator_CAST: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_CastOptions()) { - TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->in_data_type(), - ¶ms->in_data_type, - error_reporter)); - TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->out_data_type(), - ¶ms->out_data_type, - error_reporter)); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_LSH_PROJECTION: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* lshParams = - op->builtin_options_as_LSHProjectionOptions()) { - params->type = parseLSHProjectionType(lshParams->type()); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* sequence_rnn_params = - op->builtin_options_as_SequenceRNNOptions()) { - params->activation = - ConvertActivation(sequence_rnn_params->fused_activation_function()); - params->time_major = sequence_rnn_params->time_major(); - params->asymmetric_quantize_inputs = - sequence_rnn_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_RNN: { - auto params = - safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* bidi_sequence_rnn_params = - op->builtin_options_as_BidirectionalSequenceRNNOptions()) { - params->activation = ConvertActivation( - bidi_sequence_rnn_params->fused_activation_function()); - params->time_major = bidi_sequence_rnn_params->time_major(); - params->merge_outputs = bidi_sequence_rnn_params->merge_outputs(); - params->asymmetric_quantize_inputs = - bidi_sequence_rnn_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_RNN: { - auto params = safe_allocator.Allocate(); - 
TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* rnn_params = op->builtin_options_as_RNNOptions()) { - params->activation = - ConvertActivation(rnn_params->fused_activation_function()); - params->asymmetric_quantize_inputs = - rnn_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_EMBEDDING_LOOKUP_SPARSE: { - auto params = - safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* embedding_params = - op->builtin_options_as_EmbeddingLookupSparseOptions()) { - params->combiner = parseCombinerType(embedding_params->combiner()); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - - case BuiltinOperator_HASHTABLE_LOOKUP: - // no-op. - return kTfLiteOk; - case BuiltinOperator_DIV: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_DivOptions()) { - params->activation = - ConvertActivation(schema_params->fused_activation_function()); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_LOCAL_RESPONSE_NORMALIZATION: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_LocalResponseNormalizationOptions()) { - params->radius = schema_params->radius(); - params->bias = schema_params->bias(); - params->alpha = schema_params->alpha(); - params->beta = schema_params->beta(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_LSTM: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* lstm_params = op->builtin_options_as_LSTMOptions()) { - params->activation = - ConvertActivation(lstm_params->fused_activation_function()); - params->cell_clip = lstm_params->cell_clip(); - params->proj_clip = lstm_params->proj_clip(); - switch (lstm_params->kernel_type()) { - case LSTMKernelType_FULL: - params->kernel_type = kTfLiteLSTMFullKernel; - break; - case LSTMKernelType_BASIC: - params->kernel_type = kTfLiteLSTMBasicKernel; - break; - default: - TF_LITE_REPORT_ERROR(error_reporter, - "Unhandled LSTM kernel type: %d", - lstm_params->kernel_type()); - return kTfLiteError; - } - params->asymmetric_quantize_inputs = - lstm_params->asymmetric_quantize_inputs(); - } else { - TF_LITE_REPORT_ERROR(error_reporter, - "No valid LSTM builtin options exist"); - return kTfLiteError; - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_LSTM: { - auto params = - safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* seq_lstm_params = - op->builtin_options_as_UnidirectionalSequenceLSTMOptions()) { - params->activation = - ConvertActivation(seq_lstm_params->fused_activation_function()); - params->cell_clip = seq_lstm_params->cell_clip(); - params->proj_clip = seq_lstm_params->proj_clip(); - params->time_major = seq_lstm_params->time_major(); - params->asymmetric_quantize_inputs = - seq_lstm_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_BIDIRECTIONAL_SEQUENCE_LSTM: { - auto params = - safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* bidi_lstm_params = - op->builtin_options_as_BidirectionalSequenceLSTMOptions()) { - 
params->activation = - ConvertActivation(bidi_lstm_params->fused_activation_function()); - params->cell_clip = bidi_lstm_params->cell_clip(); - params->proj_clip = bidi_lstm_params->proj_clip(); - params->merge_outputs = bidi_lstm_params->merge_outputs(); - params->time_major = bidi_lstm_params->time_major(); - params->asymmetric_quantize_inputs = - bidi_lstm_params->asymmetric_quantize_inputs(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_RESIZE_BILINEAR: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_ResizeBilinearOptions()) { - params->align_corners = schema_params->align_corners(); - params->half_pixel_centers = schema_params->half_pixel_centers(); - } else { - // Some older models did not populate the ResizeBilinearOptions field in - // the flatbuffer, so ensure it's set to a sensible default. - params->align_corners = false; - params->half_pixel_centers = false; - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_SKIP_GRAM: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* skip_gram_params = - op->builtin_options_as_SkipGramOptions()) { - params->ngram_size = skip_gram_params->ngram_size(); - params->max_skip_size = skip_gram_params->max_skip_size(); - params->include_all_ngrams = skip_gram_params->include_all_ngrams(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_SPACE_TO_DEPTH: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_SpaceToDepthOptions()) { - params->block_size = schema_params->block_size(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_DEPTH_TO_SPACE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_DepthToSpaceOptions()) { - params->block_size = schema_params->block_size(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_GATHER: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - params->axis = 0; - if (const auto* gather_params = op->builtin_options_as_GatherOptions()) { - params->axis = gather_params->axis(); - } - - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_SPLIT_V: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_SplitVOptions()) { - params->num_splits = schema_params->num_splits(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_SQUEEZE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_SqueezeOptions()) { - const auto* squeeze_dims = schema_params->squeeze_dims(); - TF_LITE_ENSURE_STATUS(FlatBufferIntVectorToArray( - sizeof(params->squeeze_dims), squeeze_dims, params->squeeze_dims, - error_reporter, "squeeze")); - params->num_squeeze_dims = squeeze_dims->size(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_TRANSPOSE_CONV: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != 
nullptr); - if (const auto* transpose_conv_params = - op->builtin_options_as_TransposeConvOptions()) { - params->padding = ConvertPadding(transpose_conv_params->padding()); - params->stride_width = transpose_conv_params->stride_w(); - params->stride_height = transpose_conv_params->stride_h(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_SPARSE_TO_DENSE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* sparse_to_dense_params = - op->builtin_options_as_SparseToDenseOptions()) { - params->validate_indices = sparse_to_dense_params->validate_indices(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_SHAPE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_ShapeOptions()) { - TF_LITE_ENSURE_STATUS(ConvertTensorType( - schema_params->out_type(), ¶ms->out_type, error_reporter)); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_DELEGATE: { - // TODO(ycling): Revisit when supporting saving delegated models. - TF_LITE_REPORT_ERROR(error_reporter, - "DELEGATE op shouldn't exist in model."); - return kTfLiteError; - } - case BuiltinOperator_FAKE_QUANT: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = - op->builtin_options_as_FakeQuantOptions()) { - params->min = schema_params->min(); - params->max = schema_params->max(); - params->num_bits = schema_params->num_bits(); - params->narrow_range = schema_params->narrow_range(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_ONE_HOT: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* schema_params = op->builtin_options_as_OneHotOptions()) { - params->axis = schema_params->axis(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_LEAKY_RELU: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* leaky_relu_params = - op->builtin_options_as_LeakyReluOptions()) { - params->alpha = leaky_relu_params->alpha(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_MIRROR_PAD: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - const auto* mirror_pad_params = op->builtin_options_as_MirrorPadOptions(); - if (mirror_pad_params != nullptr) { - params->mode = - mirror_pad_params->mode() == tflite::MirrorPadMode_REFLECT - ? TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingReflect - : TfLiteMirrorPaddingMode::kTfLiteMirrorPaddingSymmetric; - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_UNIQUE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - const auto* unique_params = op->builtin_options_as_UniqueOptions(); - if (unique_params != nullptr) { - params->index_out_type = - unique_params->idx_out_type() == tflite::TensorType_INT64 - ? 
TfLiteType::kTfLiteInt64 - : TfLiteType::kTfLiteInt32; - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_REVERSE_SEQUENCE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* reverse_seq_params = - op->builtin_options_as_ReverseSequenceOptions()) { - params->seq_dim = reverse_seq_params->seq_dim(); - params->batch_dim = reverse_seq_params->batch_dim(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_IF: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* if_params = op->builtin_options_as_IfOptions()) { - params->then_subgraph_index = if_params->then_subgraph_index(); - params->else_subgraph_index = if_params->else_subgraph_index(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_WHILE: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* while_params = op->builtin_options_as_WhileOptions()) { - params->cond_subgraph_index = while_params->cond_subgraph_index(); - params->body_subgraph_index = while_params->body_subgraph_index(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - case BuiltinOperator_BATCH_MATMUL: { - auto params = safe_allocator.Allocate(); - TF_LITE_ENSURE(error_reporter, params != nullptr); - if (const auto* bmm_params = - op->builtin_options_as_BatchMatMulOptions()) { - params->adj_x = bmm_params->adj_x(); - params->adj_y = bmm_params->adj_y(); - } - *builtin_data = params.release(); - return kTfLiteOk; - } - // Below are the ops with no builtin_data structure. - case BuiltinOperator_BATCH_TO_SPACE_ND: - // TODO(aselle): Implement call in BuiltinOptions, but nullptrs are - // ok for now, since there is no call implementation either. - case BuiltinOperator_CALL: - case BuiltinOperator_CONCAT_EMBEDDINGS: - case BuiltinOperator_COS: - case BuiltinOperator_CUSTOM: - case BuiltinOperator_ELU: - case BuiltinOperator_EMBEDDING_LOOKUP: - case BuiltinOperator_EQUAL: - case BuiltinOperator_EXP: - case BuiltinOperator_EXPAND_DIMS: - case BuiltinOperator_LOG_SOFTMAX: - case BuiltinOperator_MATRIX_DIAG: - case BuiltinOperator_MATRIX_SET_DIAG: - case BuiltinOperator_RELU_N1_TO_1: - case BuiltinOperator_SELECT: - case BuiltinOperator_SELECT_V2: - case BuiltinOperator_SLICE: - case BuiltinOperator_SPACE_TO_BATCH_ND: - case BuiltinOperator_TILE: - case BuiltinOperator_TOPK_V2: - case BuiltinOperator_TRANSPOSE: - case BuiltinOperator_POW: - case BuiltinOperator_FLOOR_DIV: - case BuiltinOperator_ZEROS_LIKE: - case BuiltinOperator_FILL: - case BuiltinOperator_FLOOR_MOD: - case BuiltinOperator_RANGE: - case BuiltinOperator_SQUARED_DIFFERENCE: - case BuiltinOperator_REVERSE_V2: - case BuiltinOperator_ADD_N: - case BuiltinOperator_GATHER_ND: - case BuiltinOperator_WHERE: - case BuiltinOperator_RANK: - case BuiltinOperator_NON_MAX_SUPPRESSION_V4: - case BuiltinOperator_NON_MAX_SUPPRESSION_V5: - case BuiltinOperator_SCATTER_ND: - case BuiltinOperator_DENSIFY: - case BuiltinOperator_SEGMENT_SUM: - return kTfLiteOk; - } +// TODO(b/145762662): It would be preferable to have the build graph for TF Lite +// Micro not have the ParseOpData function at all. This would require splitting +// the current file into two separate files, one of which defines the +// ParseOpData function and the other that defines the operator specific parse +// functions (e.g. ParseAdd). 
+// +// Such a split was attempted but was not worth the effort at the time because +// of the following reasons: +// * We could either duplicate the functions and the SafeBuiltinDataAllocator +// class in the anonymous namespace of this file, or attempt to make a common +// library with these helper functions and class. +// * Making a common library with a separate build target was not feasible as +// it introduced circular dependencies due to the ErrorReporter and a common +// .cc and .h within the same api build target the also cause circular +// dependencies due to the BuiltinDataAllocator class. +// * If all the builtin operators were to have their own parse functions, or we +// were ok with some amount of code duplication, then this split of the .cc +// files would be a lot more feasible. +#ifdef TF_LITE_STATIC_MEMORY + TF_LITE_REPORT_ERROR( + error_reporter, + "ParseOpData is unsupported on TfLiteMicro, please use the operator " + "specific parse functions (e.g. ParseAdd etc.).\n"); return kTfLiteError; -} // NOLINT[readability/fn_size] +#else + return ParseOpDataTfLite(op, op_type, error_reporter, allocator, + builtin_data); +#endif +} } // namespace tflite From 4a8aa32dbdd4808a8f704f56f5ef9963be1a14f6 Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Wed, 8 Jul 2020 11:31:12 -0700 Subject: [PATCH 68/88] Run gpu doctest with a fixed set of devices This is to ensure that the test can run with different physical GPUs, and we can change our examples to use two GPUs which are more real. PiperOrigin-RevId: 320229497 Change-Id: I0bc875ab3a54f32ebc6a705c7c350795b341b31f --- .../python/distribute/distribute_lib.py | 37 ++++---- tensorflow/python/distribute/input_lib.py | 86 ++++++------------- .../python/distribute/mirrored_strategy.py | 19 ++-- tensorflow/python/distribute/values.py | 16 ++-- tensorflow/tools/docs/BUILD | 27 ++++-- tensorflow/tools/docs/tf_doctest.py | 42 +++++++++ 6 files changed, 128 insertions(+), 99 deletions(-) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index 78a199fe782..e82e60eddb6 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -533,7 +533,7 @@ class ValueContext(object): 2. Passed in by `experimental_distribute_values_from_function`. - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def value_fn(value_context): ... return value_context.num_replicas_in_sync >>> distributed_values = ( @@ -541,7 +541,7 @@ class ValueContext(object): ... value_fn)) >>> local_result = strategy.experimental_local_results(distributed_values) >>> local_result - (1,) + (2, 2) """ @@ -792,13 +792,14 @@ class StrategyBase(object): This method returns a context manager, and is used as follows: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> # Variable created inside scope: >>> with strategy.scope(): ... mirrored_variable = tf.Variable(1.) >>> mirrored_variable MirroredVariable:{ - 0: + 0: , + 1: } >>> # Variable created outside scope: >>> regular_variable = tf.Variable(1.) @@ -1157,18 +1158,21 @@ class StrategyBase(object): 1. Constant tensor input. - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> tensor_input = tf.constant(3.0) >>> @tf.function ... def replica_fn(input): ... 
return input*2.0 >>> result = strategy.run(replica_fn, args=(tensor_input,)) >>> result - + PerReplica:{ + 0: , + 1: + } 2. DistributedValues input. - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> @tf.function ... def run(): ... def value_fn(value_context): @@ -1181,7 +1185,7 @@ class StrategyBase(object): ... return strategy.run(replica_fn2, args=(distributed_values,)) >>> result = run() >>> result - + Args: fn: The function to run. The output must be a `tf.nest` of `Tensor`s. @@ -1218,7 +1222,7 @@ class StrategyBase(object): def reduce(self, reduce_op, value, axis): """Reduce `value` across replicas and return result on current device. - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def step_fn(): ... i = tf.distribute.get_replica_context().replica_id_in_sync_group ... return tf.identity(i) @@ -1226,7 +1230,7 @@ class StrategyBase(object): >>> per_replica_result = strategy.run(step_fn) >>> total = strategy.reduce("SUM", per_replica_result, axis=None) >>> total - + To see how this would look with multiple replicas, consider the same example with MirroredStrategy with 2 GPUs: @@ -1749,7 +1753,7 @@ class Strategy(StrategyBase): 1. Return constant value per replica: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def value_fn(ctx): ... return tf.constant(1.) >>> distributed_values = ( @@ -1757,11 +1761,12 @@ class Strategy(StrategyBase): ... value_fn)) >>> local_result = strategy.experimental_local_results(distributed_values) >>> local_result - (,) + (, + ) 2. Distribute values in array based on replica_id: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> array_value = np.array([3., 2., 1.]) >>> def value_fn(ctx): ... return array_value[ctx.replica_id_in_sync_group] @@ -1770,11 +1775,11 @@ class Strategy(StrategyBase): ... value_fn)) >>> local_result = strategy.experimental_local_results(distributed_values) >>> local_result - (3.0,) + (3.0, 2.0) 3. Specify values using num_replicas_in_sync: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> def value_fn(ctx): ... return ctx.num_replicas_in_sync >>> distributed_values = ( @@ -1782,7 +1787,7 @@ class Strategy(StrategyBase): ... value_fn)) >>> local_result = strategy.experimental_local_results(distributed_values) >>> local_result - (1,) + (2, 2) 4. Place values on devices and distribute: diff --git a/tensorflow/python/distribute/input_lib.py b/tensorflow/python/distribute/input_lib.py index 64089e54bfa..dc1eeb38f8e 100644 --- a/tensorflow/python/distribute/input_lib.py +++ b/tensorflow/python/distribute/input_lib.py @@ -165,7 +165,7 @@ class DistributedIteratorInterface(collections.Iterator, Example use: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.range(100).batch(2) >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) >>> dist_dataset_iterator = iter(dist_dataset) @@ -176,18 +176,8 @@ class DistributedIteratorInterface(collections.Iterator, >>> for _ in range(step_num): ... 
strategy.run(one_step, args=(dist_dataset_iterator.get_next(),)) >>> strategy.experimental_local_results(dist_dataset_iterator.get_next()) - (,) - - The above example corresponds to the case where you have only one device. If - you have two devices, for example, - ```python - strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) - ``` - Then the final line will print out: - ```python (, ) - ``` Returns: A single `tf.Tensor` or a `tf.distribute.DistributedValues` which contains @@ -207,25 +197,14 @@ class DistributedIteratorInterface(collections.Iterator, Example usage: >>> global_batch_size = 16 - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size) >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> distributed_iterator.element_spec - (TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), - TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)) - - The above example corresponds to the case where you have only one device. If - you have two devices, for example, - ```python - strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) - ``` - Then the final line will print out: - ```python (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)), PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None), TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))) - ``` Returns: A nested structure of `tf.TypeSpec` objects matching the structure of an @@ -244,7 +223,7 @@ class DistributedIteratorInterface(collections.Iterator, Example usage: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> global_batch_size = 2 >>> steps_per_loop = 2 >>> dataset = tf.data.Dataset.range(10).batch(global_batch_size) @@ -312,8 +291,8 @@ class DistributedDatasetInterface(collections.Iterable, * use a pythonic for-loop construct: - >>> global_batch_size = 2 - >>> strategy = tf.distribute.MirroredStrategy() + >>> global_batch_size = 4 + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(4).batch(global_batch_size) >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) >>> @tf.function @@ -324,12 +303,14 @@ class DistributedDatasetInterface(collections.Iterable, ... # train_step trains the model using the dataset elements ... loss = strategy.run(train_step, args=(x,)) ... print("Loss is", loss) - Loss is tf.Tensor( - [[0.7] - [0.7]], shape=(2, 1), dtype=float32) - Loss is tf.Tensor( + Loss is PerReplica:{ + 0: tf.Tensor( + [[0.7] + [0.7]], shape=(2, 1), dtype=float32), + 1: tf.Tensor( [[0.7] [0.7]], shape=(2, 1), dtype=float32) + } Placing the loop inside a `tf.function` will give a performance boost. 
However `break` and `return` are currently not supported if the loop is @@ -342,7 +323,7 @@ class DistributedDatasetInterface(collections.Iterable, `tf.distribute.DistributedIterator` >>> global_batch_size = 4 - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> train_dataset = tf.data.Dataset.from_tensors(([1.],[1.])).repeat(50).batch(global_batch_size) >>> train_dist_dataset = strategy.experimental_distribute_dataset(train_dataset) >>> @tf.function @@ -362,10 +343,10 @@ class DistributedDatasetInterface(collections.Iterable, ... total_loss += distributed_train_step(next(dist_dataset_iterator)) ... num_batches += 1 ... average_train_loss = total_loss / num_batches - ... template = ("Epoch {}, Loss: {}") + ... template = ("Epoch {}, Loss: {:.4f}") ... print (template.format(epoch+1, average_train_loss)) - Epoch 1, Loss: 0.10000000894069672 - Epoch 2, Loss: 0.10000000894069672 + Epoch 1, Loss: 0.2000 + Epoch 2, Loss: 0.2000 To achieve a performance improvement, you can also wrap the `strategy.run` @@ -389,10 +370,10 @@ class DistributedDatasetInterface(collections.Iterable, For example: - >>> global_batch_size = 2 + >>> global_batch_size = 4 >>> epochs = 1 >>> steps_per_epoch = 1 - >>> mirrored_strategy = tf.distribute.MirroredStrategy() + >>> mirrored_strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensors(([2.])).repeat(100).batch(global_batch_size) >>> dist_dataset = mirrored_strategy.experimental_distribute_dataset(dataset) >>> @tf.function(input_signature=[dist_dataset.element_spec]) @@ -405,9 +386,14 @@ class DistributedDatasetInterface(collections.Iterable, ... for _ in range(steps_per_epoch): ... output = train_step(next(iterator)) ... print(output) - tf.Tensor( + PerReplica:{ + 0: tf.Tensor( + [[4.] + [4.]], shape=(2, 1), dtype=float32), + 1: tf.Tensor( [[4.] [4.]], shape=(2, 1), dtype=float32) + } Visit the [tutorial](https://www.tensorflow.org/tutorials/distribute/input) @@ -422,25 +408,14 @@ class DistributedDatasetInterface(collections.Iterable, Example usage: >>> global_batch_size = 4 - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensor_slices([1, 2, 3, 4]).repeat().batch(global_batch_size) >>> distributed_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> print(next(distributed_iterator)) - tf.Tensor([1 2 3 4], shape=(4,), dtype=int32) - - - The above example corresponds to the case where you have only one device. 
If - you have two devices, for example, - ```python - strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) - ``` - Then the final line will print out: - ```python PerReplica:{ 0: tf.Tensor([1 2], shape=(2,), dtype=int32), 1: tf.Tensor([3 4], shape=(2,), dtype=int32) } - ``` Returns: An `tf.distribute.DistributedIterator` instance for the given @@ -456,25 +431,14 @@ class DistributedDatasetInterface(collections.Iterable, Example usage: >>> global_batch_size = 16 - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensors(([1.],[2])).repeat(100).batch(global_batch_size) >>> dist_dataset = strategy.experimental_distribute_dataset(dataset) >>> dist_dataset.element_spec - (TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), - TensorSpec(shape=(None, 1), dtype=tf.int32, name=None)) - - The above example corresponds to the case where you have only one device. If - you have two devices, for example, - ```python - strategy = tf.distribute.MirroredStrategy(['/gpu:0', '/gpu:1']) - ``` - Then the final line will print out: - ```python (PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.float32, name=None), TensorSpec(shape=(None, 1), dtype=tf.float32, name=None)), PerReplicaSpec(TensorSpec(shape=(None, 1), dtype=tf.int32, name=None), TensorSpec(shape=(None, 1), dtype=tf.int32, name=None))) - ``` Returns: A nested structure of `tf.TypeSpec` objects matching the structure of an diff --git a/tensorflow/python/distribute/mirrored_strategy.py b/tensorflow/python/distribute/mirrored_strategy.py index 91efce92793..b424f798476 100644 --- a/tensorflow/python/distribute/mirrored_strategy.py +++ b/tensorflow/python/distribute/mirrored_strategy.py @@ -199,13 +199,14 @@ class MirroredStrategy(distribute_lib.Strategy): will use the available CPUs. Note that TensorFlow treats all CPUs on a machine as a single device, and uses threads internally for parallelism. - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> with strategy.scope(): ... x = tf.Variable(1.) >>> x MirroredVariable:{ - 0: - } + 0: , + 1: + } While using distribution strategies, all the variable creation should be done within the strategy's scope. This will replicate the variables across all the @@ -219,13 +220,15 @@ class MirroredStrategy(distribute_lib.Strategy): ... def create_variable(): ... if not x: ... x.append(tf.Variable(1.)) - >>> strategy = tf.distribute.MirroredStrategy() + ... return x[0] + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> with strategy.scope(): - ... create_variable() - ... print (x[0]) + ... _ = create_variable() + ... print(x[0]) MirroredVariable:{ - 0: - } + 0: , + 1: + } `experimental_distribute_dataset` can be used to distribute the dataset across the replicas when writing your own training loop. If you are using `.fit` and diff --git a/tensorflow/python/distribute/values.py b/tensorflow/python/distribute/values.py index 399e8b80a19..087619aadd8 100644 --- a/tensorflow/python/distribute/values.py +++ b/tensorflow/python/distribute/values.py @@ -93,14 +93,14 @@ class DistributedValues(object): 1. 
Created from a `tf.distribute.DistributedDataset`: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> distributed_values = next(dataset_iterator) 2. Returned by `run`: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> @tf.function ... def run(): ... ctx = tf.distribute.get_replica_context() @@ -109,7 +109,7 @@ class DistributedValues(object): 3. As input into `run`: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> distributed_values = next(dataset_iterator) @@ -120,7 +120,7 @@ class DistributedValues(object): 4. Reduce value: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> distributed_values = next(dataset_iterator) @@ -128,16 +128,16 @@ class DistributedValues(object): ... distributed_values, ... axis = 0) - 5. Inspect per replica values: + 5. Inspect local replica values: - >>> strategy = tf.distribute.MirroredStrategy() + >>> strategy = tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"]) >>> dataset = tf.data.Dataset.from_tensor_slices([5., 6., 7., 8.]).batch(2) >>> dataset_iterator = iter(strategy.experimental_distribute_dataset(dataset)) >>> per_replica_values = strategy.experimental_local_results( ... distributed_values) >>> per_replica_values - (,) + (, + ) """ diff --git a/tensorflow/tools/docs/BUILD b/tensorflow/tools/docs/BUILD index 21a227e03f9..9814059be08 100644 --- a/tensorflow/tools/docs/BUILD +++ b/tensorflow/tools/docs/BUILD @@ -3,6 +3,10 @@ load("//tensorflow:tensorflow.bzl", "py_test") load("//tensorflow/python/tpu:tpu.bzl", "tpu_py_test") +load( + "//tensorflow/core/platform:build_config_root.bzl", + "tf_gpu_tests_tags", +) package( default_visibility = ["//tensorflow:__subpackages__"], @@ -11,7 +15,18 @@ package( exports_files(["LICENSE"]) -tpu_module = "tpu.,distribute.tpu_strategy,distribute.cluster_resolver.tpu,distribute.cluster_resolver.tpu_oss" +tpu_module = [ + "tpu.", + "distribute.tpu_strategy", + "distribute.cluster_resolver.tpu", + "distribute.cluster_resolver.tpu_oss", +] + +# tf.distribute docstring often uses GPU, so they're only covered in +# tf_doctest_gpu. 
+distribute_module = [ + "distribute.", +] py_library( name = "tf_doctest_lib", @@ -25,7 +40,7 @@ py_library( py_test( name = "tf_doctest", srcs = ["tf_doctest.py"], - args = ["--module_prefix_skip=" + tpu_module], + args = ["--module_prefix_skip=" + ",".join(tpu_module + distribute_module)], python_version = "PY3", tags = [ "no_oss_py2", @@ -46,7 +61,7 @@ py_test( tpu_py_test( name = "tf_doctest_tpu", srcs = ["tf_doctest.py"], - args = ["--module=" + tpu_module], + args = ["--module=" + ",".join(tpu_module)], disable_experimental = True, disable_v3 = True, main = "tf_doctest.py", @@ -70,7 +85,8 @@ py_test( srcs = ["tf_doctest.py"], args = [ "--module=distribute.", - "--module_prefix_skip=" + tpu_module, + "--module_prefix_skip=" + ",".join(tpu_module), + "--required_gpus=2", ], main = "tf_doctest.py", python_version = "PY3", @@ -82,8 +98,7 @@ py_test( "noasan", "nomsan", "notsan", - "requires-gpu-nvidia", - ], + ] + tf_gpu_tests_tags(), deps = [ ":tf_doctest_lib", "//tensorflow:tensorflow_py", diff --git a/tensorflow/tools/docs/tf_doctest.py b/tensorflow/tools/docs/tf_doctest.py index 00bf3492787..40b06c6c53f 100644 --- a/tensorflow/tools/docs/tf_doctest.py +++ b/tensorflow/tools/docs/tf_doctest.py @@ -46,6 +46,8 @@ flags.DEFINE_list('module_prefix_skip', [], flags.DEFINE_boolean('list', None, 'List all the modules in the core package imported.') flags.DEFINE_string('file', None, 'A specific file to run doctest on.') +flags.DEFINE_integer('required_gpus', 0, + 'The number of GPUs required for the tests.') flags.mark_flags_as_mutual_exclusive(['module', 'file']) flags.mark_flags_as_mutual_exclusive(['list', 'file']) @@ -128,6 +130,38 @@ def get_module_and_inject_docstring(file_path): return [file_module] +def setup_gpu(required_gpus): + """Sets up the GPU devices. + + If there're more available GPUs than needed, it hides the additional ones. If + there're less, it creates logical devices. This is to make sure the tests see + a fixed number of GPUs regardless of the environment. + + Args: + required_gpus: an integer. The number of GPUs required. + + Raises: + ValueError: if num_gpus is larger than zero but no GPU is available. + """ + if required_gpus == 0: + return + available_gpus = tf.config.experimental.list_physical_devices('GPU') + if not available_gpus: + raise ValueError('requires at least one physical GPU') + if len(available_gpus) >= required_gpus: + tf.config.set_visible_devices(available_gpus[:required_gpus]) + else: + # Create logical GPUs out of one physical GPU for simplicity. Note that the + # other physical GPUs are still available and corresponds to one logical GPU + # each. + num_logical_gpus = required_gpus - len(available_gpus) + 1 + logical_gpus = [ + tf.config.LogicalDeviceConfiguration(memory_limit=256) + for _ in range(num_logical_gpus) + ] + tf.config.set_logical_device_configuration(available_gpus[0], logical_gpus) + + class TfTestCase(tf.test.TestCase): def set_up(self, test): @@ -178,6 +212,14 @@ def load_tests(unused_loader, tests, unused_ignore): )) return tests + +# We can only create logical devices before initializing Tensorflow. This is +# called by unittest framework before running any test. 
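# [Editor's sketch; illustrative only, not part of the patch] To make the
# setup_gpu() helper above concrete: on a host with one physical GPU and
# --required_gpus=2, the `else` branch creates
# num_logical_gpus = 2 - 1 + 1 = 2 logical devices (256 MB each) out of
# GPU 0, so the doctests always see exactly two GPUs. Assumed usage:
#
#   setup_gpu(required_gpus=2)
#   assert len(tf.config.list_logical_devices('GPU')) == 2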
+# https://docs.python.org/3/library/unittest.html#setupmodule-and-teardownmodule +def setUpModule(): + setup_gpu(FLAGS.required_gpus) + + if __name__ == '__main__': recursive_import(tf_root) absltest.main() From afed3a82572d2bc663f1ed98507402f725f184a8 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 8 Jul 2020 11:35:09 -0700 Subject: [PATCH 69/88] Add helper functions to tf_device.replicate op for determining and accessing replicated and packed inputs from block arguments. This exposes reusable functions for determining whether a block argument is forwarding a replicated input or a packed input, and which operand it is forwarding. operand_segment_sizes is also fixed to set packed inputs size correctly in the custom build function. PiperOrigin-RevId: 320230346 Change-Id: I7971aff5f3dfa37278b607b4f8b4ac738eb7659d --- .../compiler/mlir/tensorflow/ir/tf_device.cc | 65 ++++++++++++++++++- .../mlir/tensorflow/ir/tf_device_ops.td | 8 +++ 2 files changed, 71 insertions(+), 2 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc index 67f2b7cb05e..77008b55672 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device.cc @@ -569,8 +569,10 @@ void BuildReplicateOp( // Add derived `operand_segment_sizes` attribute. int32_t num_replicated_inputs = replicated_inputs.size() * n; - auto operand_segment_sizes = DenseIntElementsAttr::get( - VectorType::get({2}, builder->getI32Type()), {num_replicated_inputs, 0}); + int32_t num_packed_inputs = packed_inputs.size(); + auto operand_segment_sizes = + DenseIntElementsAttr::get(VectorType::get({2}, builder->getI32Type()), + {num_replicated_inputs, num_packed_inputs}); state->addAttribute(kOperandSegmentSizesAttr, operand_segment_sizes); for (const auto& output_type : replica_output_types) @@ -600,6 +602,65 @@ void ReplicateOp::build( packed_inputs, replica_output_types); } +// Returns the number of packed block arguments. +unsigned ReplicateOp::GetNumPackedBlockArguments() { + return packed_inputs().size(); +} + +// Returns the number of replicated block arguments. +unsigned ReplicateOp::GetNumReplicatedBlockArguments() { + return GetBody().getNumArguments() - GetNumPackedBlockArguments(); +} + +// Returns the replicated block arguments. A copy should be made if the +// replicate op is being modified. +llvm::ArrayRef ReplicateOp::GetReplicatedBlockArguments() { + return GetBody().getArguments().drop_back(GetNumPackedBlockArguments()); +} + +// Returns the packed block arguments. A copy should be made if the replicate op +// is being modified. +llvm::ArrayRef ReplicateOp::GetPackedBlockArguments() { + return GetBody().getArguments().take_back(GetNumPackedBlockArguments()); +} + +// Checks if a block argument is replicated (forwarding replicated inputs). +bool ReplicateOp::IsReplicatedBlockArgument(BlockArgument block_arg) { + assert(block_arg.getOwner() == &GetBody()); + return block_arg.getArgNumber() < GetNumReplicatedBlockArguments(); +} + +// Checks if a block argument is packed (forwarding a packed input). +bool ReplicateOp::IsPackedBlockArgument(BlockArgument block_arg) { + return !IsReplicatedBlockArgument(block_arg); +} + +// Returns the operand index of the operand being forwarded as a +// replicated/packed block argument for a given replica. This assumes a valid +// block argument (of the replicate op) and a valid replica is provided. 
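// [Editor's note: illustrative example, not part of the patch] With n = 2
// replicas, two replicated block arguments and one packed block argument,
// the replicate op carries five operands: indices 0-3 hold the per-replica
// values of the two replicated inputs and index 4 holds the packed input.
// The mapping below then gives, e.g., replicated block argument #1 in
// replica 0 -> operand 1 * 2 + 0 = 2, and the packed block argument #2 ->
// operand 2 - 2 + 4 = 4 for every replica.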
+unsigned ReplicateOp::GetReplicaOperandIndexForBlockArgument( + BlockArgument block_arg, unsigned replica) { + const int32_t num_replicas = nAttr().getInt(); + assert(replica < num_replicas && block_arg.getOwner() == &GetBody()); + + const unsigned num_replicated_args = GetNumReplicatedBlockArguments(); + if (block_arg.getArgNumber() < num_replicated_args) + return block_arg.getArgNumber() * num_replicas + replica; + + return block_arg.getArgNumber() - num_replicated_args + + replicated_inputs().size(); +} + +// Returns the operand being forwarded as a replicated/packed block argument for +// a given replica. This assumes a valid block argument (of the replicate op) +// and a valid replica is provided. +Value ReplicateOp::GetReplicaOperandForBlockArgument(BlockArgument block_arg, + unsigned replica) { + const unsigned operand_index = + GetReplicaOperandIndexForBlockArgument(block_arg, replica); + return getOperand(operand_index); +} + //===----------------------------------------------------------------------===// // Canonicalization patterns //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td index 01eea8c94cd..3a92e3237dc 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_device_ops.td @@ -283,6 +283,14 @@ For example: let extraClassDeclaration = [{ Block &GetBody() { return getOperation()->getRegion(0).front(); } + unsigned GetNumReplicatedBlockArguments(); + unsigned GetNumPackedBlockArguments(); + llvm::ArrayRef GetPackedBlockArguments(); + llvm::ArrayRef GetReplicatedBlockArguments(); + bool IsReplicatedBlockArgument(BlockArgument block_arg); + bool IsPackedBlockArgument(BlockArgument block_arg); + unsigned GetReplicaOperandIndexForBlockArgument(BlockArgument block_arg, unsigned replica); + Value GetReplicaOperandForBlockArgument(BlockArgument block_arg, unsigned replica); }]; let builders = [ From 2d18e2be7124814d81369627f73d2ac98feee711 Mon Sep 17 00:00:00 2001 From: "A. Unique TensorFlower" Date: Wed, 8 Jul 2020 11:45:36 -0700 Subject: [PATCH 70/88] Go: Update generated wrapper functions for TensorFlow ops. PiperOrigin-RevId: 320232518 Change-Id: I191e4b25ef8fb243b4180eb4a9a696701490f2cb --- tensorflow/go/op/wrappers.go | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/tensorflow/go/op/wrappers.go b/tensorflow/go/op/wrappers.go index b44b7965040..843ef2fb7e1 100644 --- a/tensorflow/go/op/wrappers.go +++ b/tensorflow/go/op/wrappers.go @@ -41478,13 +41478,13 @@ func ParseExample(scope *Scope, serialized tf.Output, names tf.Output, sparse_ke // DatasetToGraphAttr is an optional argument to DatasetToGraph. type DatasetToGraphAttr func(optionalAttr) -// DatasetToGraphStatefulAllowlist sets the optional stateful_allowlist attribute to value. +// DatasetToGraphStatefulWhitelist sets the optional stateful_whitelist attribute to value. 
// If not specified, defaults to <> // // REQUIRES: len(value) >= 0 -func DatasetToGraphStatefulAllowlist(value []string) DatasetToGraphAttr { +func DatasetToGraphStatefulWhitelist(value []string) DatasetToGraphAttr { return func(m optionalAttr) { - m["stateful_allowlist"] = value + m["stateful_whitelist"] = value } } From 0a95e95853685765300c483245ce098b5c40f684 Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 8 Jul 2020 12:22:44 -0700 Subject: [PATCH 71/88] Canonicalize identity arithmetic operations with dynamic shapes PiperOrigin-RevId: 320240434 Change-Id: I5b095eeb95f42e6126939600aebf5efe802198e7 --- .../compiler/mlir/tensorflow/ir/tf_ops.cc | 64 ++++++++++++------- .../mlir/tensorflow/tests/canonicalize.mlir | 61 ++++++++++++++++++ 2 files changed, 103 insertions(+), 22 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index eb027748d28..a26de6b9c83 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -506,28 +506,52 @@ LogicalResult FoldOperandsPermutation( //===----------------------------------------------------------------------===// namespace { -// Folder that returns LHS of an Arithmetic Op if the RHS is a constant -// known to be Identity (e.g X+0) +// Fold Arithmetic Op if one of the operands is a constant known to be an +// Identity (e.g. X+0, X*1, etc...). For commutative operations fold if +// known identity value is either lhs or rhs. template < typename OpT, typename std::enable_if::value>::type * = nullptr> OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, ArrayRef operands) { - auto result_op_type = arithmetic_op.getResult().getType(); auto lhs_type = arithmetic_op.x().getType().template cast(); - if (!result_op_type.template cast().hasStaticShape()) return {}; + auto rhs_type = arithmetic_op.y().getType().template cast(); + auto result_type = + arithmetic_op.getResult().getType().template cast(); - // We only handle non-broadcastable case. - if (result_op_type != lhs_type) { - return {}; - } + // We can fold arithmetic operation only of we can prove that we will not + // accidentally hide a broadcasting error. + auto is_valid_broadcasting = [](ShapedType operand_ty, ShapedType identity_ty, + ShapedType result_ty) -> bool { + // Scalar identity is broadcastable to any operand shape, we only need to + // check that operand has the same shape as a result. + bool scalar_identity = identity_ty.hasRank() && identity_ty.getRank() == 0; + if (scalar_identity) return operand_ty == result_ty; + + // If identity is not a scalar, we must verify that all shapes are equal + // and statically known. + // + // TODO(ezhulenev): Fold if identity shape is statically know to be + // broadcastable to the operand shape. + return operand_ty == result_ty && identity_ty == result_ty && + result_ty.hasStaticShape(); + }; + + // Check that we have a constant operand on one side (candidate for identity). + const bool is_commutative = + (std::is_same::value || std::is_same::value); + auto lhs_attr = operands[0].dyn_cast_or_null(); + auto rhs_attr = operands[1].dyn_cast_or_null(); + if (!rhs_attr && !(is_commutative && lhs_attr)) return {}; // Mul and Div ops have identity value one while AddV2 and SubOp have identity // value zero. - int identity = + const int identity = (std::is_same::value || std::is_same::value || - std::is_same::value); + std::is_same::value) + ? 
1 + : 0; Type element_ty = lhs_type.getElementType(); Attribute identity_attr; @@ -539,23 +563,19 @@ OpFoldResult IdentityArithmeticOpFolder(OpT arithmetic_op, return {}; } - if (auto attr = operands[1].dyn_cast_or_null()) { - if (attr.isSplat() && attr.getSplatValue() == identity_attr) + // Fold: Op(Operand, Identity) -> Operand. + if (rhs_attr && is_valid_broadcasting(lhs_type, rhs_type, result_type)) { + if (rhs_attr.isSplat() && rhs_attr.getSplatValue() == identity_attr) return arithmetic_op.x(); } - auto rhs_type = arithmetic_op.y().getType().template cast(); - // TODO(chhe): we could fold and add an identity to force the broadcast. - if (result_op_type != rhs_type) { - return {}; - } - - bool is_symmetric = - (std::is_same::value || std::is_same::value); - if (auto attr = operands[0].dyn_cast_or_null()) { - if (is_symmetric && attr.isSplat() && attr.getSplatValue() == identity_attr) + // Fold: Op(Identity, Operand) -> Operand for commutative operations. + if (lhs_attr && is_commutative && + is_valid_broadcasting(rhs_type, lhs_type, result_type)) { + if (lhs_attr.isSplat() && lhs_attr.getSplatValue() == identity_attr) return arithmetic_op.y(); } + return {}; } } // namespace diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index c5df001bf70..02a006130ad 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -190,6 +190,27 @@ func @testSubOfNeg(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8 // CHECK: return %0 } +// CHECK-LABEL: testSubOfZero +func @testSubOfZero(%arg0: tensor, %arg1: tensor<4x1xf32>) -> (tensor, tensor<4x1xf32>) { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + %1 = "tf.Sub"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "tf.Sub"(%arg1, %0) : (tensor<4x1xf32>, tensor) -> tensor<4x1xf32> + return %1, %2: tensor, tensor<4x1xf32> + +// CHECK: return %arg0, %arg1 +} + +// CHECK-LABEL: testSubOfZeroWithBroadcasting +func @testSubOfZeroWithBroadcasting(%arg0: tensor<4x1xf32>) -> tensor<4x4xf32> { + // This is an identity arithmetic operation, however we do not currently fold + // it because it has a broadcasting. + %0 = "tf.Const"() {value = dense<[[0.0, 0.0, 0.0, 0.0]]> : tensor<1x4xf32>} : () -> tensor<1x4xf32> + %1 = "tf.Sub"(%arg0, %0) : (tensor<4x1xf32>, tensor<1x4xf32>) -> tensor<4x4xf32> + return %1 : tensor<4x4xf32> + +// CHECK: return %1 +} + // CHECK-LABEL: testSquareOfSub func @testSquareOfSub(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> tensor<8x16xf32> { %0 = "tf.Sub"(%arg0, %arg1) : (tensor<8x16xf32>, tensor<8x16xf32>) -> tensor<8x16xf32> @@ -257,6 +278,46 @@ func @testAddV2OfNegRight(%arg0: tensor<8x16xf32>, %arg1: tensor<8x16xf32>) -> t // CHECK: return %0 } +// CHECK-LABEL: testAddV2IdentityScalar +func @testAddV2IdentityScalar(%arg0: tensor, %arg1: tensor, %arg2: tensor<4xf32>) -> (tensor, tensor, tensor<4xf32>) { + %0 = "tf.Const"() {value = dense<0.0> : tensor} : () -> tensor + + // Identity scalar (0.0) is foldable with operand of any shape because + // scalar is safely broadcastable to any shape. 
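  // [Editor's sketch, not part of the patch] For instance, with an assumed
  // operand type tensor<?x4xf32> and a splat scalar zero constant %zero:
  //   %y = "tf.AddV2"(%x, %zero) : (tensor<?x4xf32>, tensor<f32>) -> tensor<?x4xf32>
  // folds to %x even though the operand shape is dynamic, because a scalar
  // rhs cannot change the broadcasted result shape.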
+ + %1 = "tf.AddV2"(%arg0, %0) : (tensor, tensor) -> tensor + %2 = "tf.AddV2"(%arg1, %0) : (tensor, tensor) -> tensor + %3 = "tf.AddV2"(%arg2, %0) : (tensor<4xf32>, tensor) -> tensor<4xf32> + + %4 = "tf.AddV2"(%0, %1) : (tensor, tensor) -> tensor + %5 = "tf.AddV2"(%0, %2) : (tensor, tensor) -> tensor + %6 = "tf.AddV2"(%0, %3) : (tensor, tensor<4xf32>) -> tensor<4xf32> + + // CHECK: return %arg0, %arg1, %arg2 + return %4, %5, %6: tensor, tensor, tensor<4xf32> +} + +// CHECK-LABEL: testAddV2IdentityTensor +func @testAddV2IdentityTensor(%arg0: tensor, %arg1: tensor<4xf32>) -> (tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) { + %0 = "tf.Const"() {value = dense<[0.0, 0.0, 0.0, 0.0]> : tensor<4xf32>} : () -> tensor<4xf32> + + // If operand is a scalar, then the identity value (0.0 for addition) can + // be of any shape, because operand is safely broadcastable to any shape. + // + // However we can't fold this arithmetic operation because the operand + // shape does not match the result shape. + + %1 = "tf.AddV2"(%arg0, %0) : (tensor, tensor<4xf32>) -> tensor<4xf32> + %2 = "tf.AddV2"(%0, %arg0) : (tensor<4xf32>, tensor) -> tensor<4xf32> + + // If operand has the same shape as a result, we can fold it. + %3 = "tf.AddV2"(%arg1, %0) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + %4 = "tf.AddV2"(%0, %arg1) : (tensor<4xf32>, tensor<4xf32>) -> tensor<4xf32> + + // CHECK: return %1, %2, %arg1, %arg1 + return %1, %2, %3, %4: tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32> +} + // CHECK-LABEL: testDoubleConj func @testDoubleConj(%arg0: tensor<8x16x32x64xcomplex>) -> tensor<8x16x32x64xcomplex> { %0 = "tf.Conj"(%arg0) : (tensor<8x16x32x64xcomplex>) -> tensor<8x16x32x64xcomplex> From bf9ed625de0704f5ddb1d5aa203516ea414488b3 Mon Sep 17 00:00:00 2001 From: Yanhui Liang Date: Wed, 8 Jul 2020 12:40:32 -0700 Subject: [PATCH 72/88] Clean up the LSTM/GRU layer with the new grappler selector. PiperOrigin-RevId: 320244001 Change-Id: Iaf4285b437f64493d7b9bf8b9638c588b15e1237 --- .../python/keras/layers/recurrent_v2.py | 144 ++++++------------ 1 file changed, 47 insertions(+), 97 deletions(-) diff --git a/tensorflow/python/keras/layers/recurrent_v2.py b/tensorflow/python/keras/layers/recurrent_v2.py index 33babb54357..58eb0bb025b 100644 --- a/tensorflow/python/keras/layers/recurrent_v2.py +++ b/tensorflow/python/keras/layers/recurrent_v2.py @@ -385,6 +385,17 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): else: logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name) + # The first two attributes are added to support TFLite use case. 
+ supportive_attributes = { + 'time_major': time_major, + 'go_backwards': go_backwards, + _FUNCTION_API_NAME_ATTRIBUTE: 'gru_' + str(uuid.uuid4()) + } + self.defun_gru_with_backend_selection = function.defun_with_attributes( + gru_with_backend_selection, + attributes=supportive_attributes, + autograph=False) + def build(self, input_shape): super(GRU, self).build(input_shape) @@ -467,7 +478,7 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): if dropout_mask is not None: inputs = inputs * dropout_mask[0] - gpu_gru_kwargs = { + gru_kwargs = { 'inputs': inputs, 'init_h': _read_variable_value(initial_state[0]), 'kernel': _read_variable_value(self.cell.kernel), @@ -476,29 +487,11 @@ class GRU(recurrent.DropoutRNNCellMixin, recurrent.GRU): 'mask': mask, 'time_major': self.time_major, 'go_backwards': self.go_backwards, - 'sequence_lengths': sequence_lengths + 'sequence_lengths': sequence_lengths, + 'zero_output_for_mask': self.zero_output_for_mask } - normal_gru_kwargs = gpu_gru_kwargs.copy() - normal_gru_kwargs.update({ - 'zero_output_for_mask': self.zero_output_for_mask, - }) - - if context.executing_eagerly(): - device_type = _get_context_device_type() - can_use_gpu = ( - # Either user specified GPU or unspecified but GPU is available. - (device_type == _GPU_DEVICE_NAME - or (device_type is None and context.num_gpus() > 0)) - and - (mask is None or is_sequence_right_padded(mask, self.time_major))) - # Under eager context, check the device placement and prefer the - if can_use_gpu: - last_output, outputs, new_h, runtime = gpu_gru(**gpu_gru_kwargs) - else: - last_output, outputs, new_h, runtime = standard_gru(**normal_gru_kwargs) - else: - last_output, outputs, new_h, runtime = gru_with_backend_selection( - **normal_gru_kwargs) + (last_output, outputs, new_h, + runtime) = self.defun_gru_with_backend_selection(**gru_kwargs) states = [new_h] return last_output, outputs, runtime, states @@ -765,24 +758,14 @@ def gru_with_backend_selection(inputs, init_h, kernel, recurrent_kernel, bias, true_fn=input_right_padded, false_fn=input_not_right_padded) - # Each time a `tf.function` is called, we will give it a unique - # identifiable API name, so that Grappler won't get confused when it - # sees multiple GRU layers added into same graph, and it will be able - # to pair up the different implementations across them. - api_name = 'gru_' + str(uuid.uuid4()) - supportive_attribute = { - 'time_major': time_major, - 'go_backwards': go_backwards, - } - defun_standard_gru = _generate_defun_backend( - api_name, _CPU_DEVICE_NAME, standard_gru, supportive_attribute) - defun_gpu_gru = _generate_defun_backend( - api_name, _GPU_DEVICE_NAME, gpu_gru_with_fallback, supportive_attribute) + # Chooses the implementation dynamicly based on the running device. + (last_output, outputs, new_h, + runtime) = control_flow_ops.execute_fn_for_device( + { + _CPU_DEVICE_NAME: lambda: standard_gru(**params), + _GPU_DEVICE_NAME: lambda: gpu_gru_with_fallback(**params) + }, lambda: standard_gru(**params)) - # Call the normal GRU impl and register the CuDNN impl function. The - # grappler will kick in during session execution to optimize the graph. - last_output, outputs, new_h, runtime = defun_standard_gru(**params) - function.register(defun_gpu_gru, **params) return last_output, outputs, new_h, runtime @@ -1097,6 +1080,18 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): else: logging.warn(_CUDNN_NOT_AVAILABLE_MSG % self.name) + # The first two attributes are added to support TFLite use case. 
+ supportive_attributes = { + 'time_major': time_major, + 'go_backwards': go_backwards, + _FUNCTION_API_NAME_ATTRIBUTE: 'lstm_' + str(uuid.uuid4()) + } + + self.defun_lstm_with_backend_selection = function.defun_with_attributes( + lstm_with_backend_selection, + attributes=supportive_attributes, + autograph=False) + def call(self, inputs, mask=None, training=None, initial_state=None): # The input should be dense, padded with zeros. If a ragged input is fed # into the layer, it is padded and the row lengths are used for masking. @@ -1145,7 +1140,7 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): dropout_mask = self.get_dropout_mask_for_cell(inputs, training, count=4) if dropout_mask is not None: inputs = inputs * dropout_mask[0] - gpu_lstm_kwargs = { + lstm_kwargs = { 'inputs': inputs, 'init_h': _read_variable_value(initial_state[0]), 'init_c': _read_variable_value(initial_state[1]), @@ -1155,32 +1150,11 @@ class LSTM(recurrent.DropoutRNNCellMixin, recurrent.LSTM): 'mask': mask, 'time_major': self.time_major, 'go_backwards': self.go_backwards, - 'sequence_lengths': row_lengths - } - normal_lstm_kwargs = gpu_lstm_kwargs.copy() - normal_lstm_kwargs.update({ + 'sequence_lengths': row_lengths, 'zero_output_for_mask': self.zero_output_for_mask, - }) - - if context.executing_eagerly(): - device_type = _get_context_device_type() - can_use_gpu = ( - # Either user specified GPU or unspecified but GPU is available. - (device_type == _GPU_DEVICE_NAME - or (device_type is None and context.num_gpus() > 0)) - and - (mask is None or is_sequence_right_padded(mask, self.time_major))) - # Under eager context, check the device placement and prefer the - # GPU implementation when GPU is available. - if can_use_gpu: - last_output, outputs, new_h, new_c, runtime = gpu_lstm( - **gpu_lstm_kwargs) - else: - last_output, outputs, new_h, new_c, runtime = standard_lstm( - **normal_lstm_kwargs) - else: - (last_output, outputs, new_h, new_c, - runtime) = lstm_with_backend_selection(**normal_lstm_kwargs) + } + (last_output, outputs, new_h, new_c, + runtime) = self.defun_lstm_with_backend_selection(**lstm_kwargs) states = [new_h, new_c] @@ -1538,25 +1512,13 @@ def lstm_with_backend_selection(inputs, init_h, init_c, kernel, true_fn=input_right_padded, false_fn=input_not_right_padded) - # Each time a `tf.function` is called, we will give it a unique - # identifiable API name, so that Grappler won't get confused when it - # sees multiple LSTM layers added into same graph, and it will be able - # to pair up the different implementations across them. - api_name = 'lstm_' + str(uuid.uuid4()) - supportive_attribute = { - 'time_major': time_major, - 'go_backwards': go_backwards, - } - defun_standard_lstm = _generate_defun_backend( - api_name, _CPU_DEVICE_NAME, standard_lstm, supportive_attribute) - defun_gpu_lstm = _generate_defun_backend( - api_name, _GPU_DEVICE_NAME, gpu_lstm_with_fallback, supportive_attribute) - - # Call the normal LSTM impl and register the CuDNN impl function. The - # grappler will kick in during session execution to optimize the graph. - last_output, outputs, new_h, new_c, runtime = defun_standard_lstm( - **params) - function.register(defun_gpu_lstm, **params) + # Chooses the implementation dynamicly based on the running device. 
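  # [Editor's note] As in the GRU path above, execute_fn_for_device is
  # expected to run the branch keyed by the device type this op is placed on
  # (CPU vs GPU) and to fall back to the default callable given as the final
  # argument (standard_lstm here) on any other device.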
+ (last_output, outputs, new_h, new_c, + runtime) = control_flow_ops.execute_fn_for_device( + { + _CPU_DEVICE_NAME: lambda: standard_lstm(**params), + _GPU_DEVICE_NAME: lambda: gpu_lstm_with_fallback(**params) + }, lambda: standard_lstm(**params)) return last_output, outputs, new_h, new_c, runtime @@ -1619,18 +1581,6 @@ def calculate_sequence_by_mask(mask, time_major): axis=timestep_index) -def _generate_defun_backend(unique_api_name, preferred_device, func, - supportive_attributes): - function_attributes = { - _FUNCTION_API_NAME_ATTRIBUTE: unique_api_name, - _FUNCTION_DEVICE_ATTRIBUTE: preferred_device, - } - function_attributes.update(supportive_attributes) - return function.defun_with_attributes(func=func, - attributes=function_attributes, - autograph=False) - - def _get_context_device_type(): """Parse the current context and return the device type, eg CPU/GPU.""" current_device = context.context().device_name From ee7172a929cb0c3d94a094fafc60bbaa175c085d Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 8 Jul 2020 12:56:45 -0700 Subject: [PATCH 73/88] Add tf.TPUPartitionedCall to TensorFlow MLIR ODS. This op is mostly auto generated from the TensorFlow op registry but with the op interface CallOpInterface added in, matching tf.PartitionedCall/tf.StatefulPartitionedCall. PiperOrigin-RevId: 320247409 Change-Id: I5db944a3aa3330f7238cfaef190e8210944eb811 --- .../compiler/mlir/tensorflow/ir/tf_ops.td | 31 +++++++++++++++++++ 1 file changed, 31 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td index 0753c76829c..d29addf66ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.td @@ -1165,4 +1165,35 @@ array([0, 2, 2]) ); } +def TF_TPUPartitionedCallOp : TF_Op<"TPUPartitionedCall", [CallOpInterface]> { + let summary = "Calls a function placed on a specified TPU device."; + + let arguments = (ins + Variadic:$args, + I32Tensor:$device_ordinal, + + SymbolRefAttr:$f, + DefaultValuedAttr:$autotuner_thresh + ); + + let results = (outs + Variadic:$output + ); + + TF_DerivedOperandTypeListAttr Tin = TF_DerivedOperandTypeListAttr<0>; + TF_DerivedResultTypeListAttr Tout = TF_DerivedResultTypeListAttr<0>; + + let extraClassDeclaration = [{ + // Gets the argument operands to the called function. + operand_range getArgOperands() { return args(); } + + // Returns the callee of this operation. + CallInterfaceCallable getCallableForCallee() { + return getAttrOfType("f"); + } + }]; + + let verifier = [{ return VerifyPartitionedCall(*this); }]; +} + #endif // TF_OPS From 80ac3b3bcc67eb298a3e24252d7786843068de26 Mon Sep 17 00:00:00 2001 From: Michael Kuchnik Date: Wed, 8 Jul 2020 13:06:09 -0700 Subject: [PATCH 74/88] [tf.data] Change buckets for tf.data's iterator.getNext() duration. 
PiperOrigin-RevId: 320249597 Change-Id: I7ab301007351628ceebb459e05ef9ceda73957f9 --- tensorflow/core/framework/metrics.cc | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/tensorflow/core/framework/metrics.cc b/tensorflow/core/framework/metrics.cc index 959abff5e7b..738863f3646 100644 --- a/tensorflow/core/framework/metrics.cc +++ b/tensorflow/core/framework/metrics.cc @@ -86,8 +86,9 @@ auto* tf_data_fingerprint_counter = monitoring::Counter<1>::New( auto* tf_data_getnext_duration_usecs_histogram = monitoring::Sampler<0>::New( {"/tensorflow/data/getnext_duration", "Microseconds spent fetching an element from tf.data Dataset iterator."}, - // Power of 2 with bucket count 10 (1024 microseconds) - {monitoring::Buckets::Exponential(1, 2, 10)}); + // Power of 2 with bucket count 10 (1024 microseconds) and 1 second. + {monitoring::Buckets::Explicit( + {2., 4., 8., 16., 32., 64., 128., 256., 512., 1024., 1e6})}); auto* tf_data_getnext_time_between_msecs_histogram = monitoring::Sampler<0>::New( From 26e2f2e285880dae50f976bfd64154190e85a274 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 8 Jul 2020 13:14:24 -0700 Subject: [PATCH 75/88] Fix cleanup of TFLite Java NNAPI delegate PiperOrigin-RevId: 320251156 Change-Id: I252fad8d6a5b79b57da706135a42d7e7970b8bd4 --- .../delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc index df2c4030e6c..c94a523d6c1 100644 --- a/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc +++ b/tensorflow/lite/delegates/nnapi/java/src/main/native/nnapi_delegate_jni.cc @@ -84,7 +84,7 @@ JNIEXPORT void JNICALL Java_org_tensorflow_lite_nnapi_NnApiDelegate_deleteDelegate(JNIEnv* env, jclass clazz, jlong delegate) { - delete reinterpret_cast(delegate); + delete reinterpret_cast(delegate); } #ifdef __cplusplus From 6ad4a657e06c80ed71736a93b95e7b643606de1d Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 8 Jul 2020 13:20:31 -0700 Subject: [PATCH 76/88] Add shape inference pass to the beginning of the TPUBridgeV1Compat pipeline. It is possible for the TPU computation to be in a function. Running shape inference prior to V1 specific passes allow for more granular shapes/types to be propagated to functions (e.g. resource subtypes), resulting in possibly less runtime/dynamic shapes for the TPU program. PiperOrigin-RevId: 320252292 Change-Id: I431ff20c943a9542fab2a0076675d62b22280cb8 --- tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc index 7b8ae474941..1963931b497 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc +++ b/tensorflow/compiler/mlir/tensorflow/transforms/bridge.cc @@ -107,6 +107,7 @@ void CreateTPUBridgePipeline(OpPassManager &pm) { } void CreateTPUBridgePipelineV1(OpPassManager &pm) { + pm.addPass(TF::CreateTFShapeInferencePass()); // For V1 compatibility, we process a module where the graph does not have // feeds and fetched. We extract first the TPU computation in a submodule, // where it'll be in a function with args and returned values, much more like From 072f95d12b051ff337c25ef5a75dce5507e82dcd Mon Sep 17 00:00:00 2001 From: "A. 
Unique TensorFlower" Date: Wed, 8 Jul 2020 13:46:19 -0700 Subject: [PATCH 77/88] Add pass to update tpu embedding ops pass. PiperOrigin-RevId: 320257175 Change-Id: I799e78da8f00e1097b0d5cf1eea23dfb4277389a --- tensorflow/compiler/mlir/tensorflow/BUILD | 1 + ...pu_update_embedding_enqueue_op_inputs.mlir | 79 +++++++++ .../mlir/tensorflow/transforms/passes.h | 5 + .../tpu_update_embedding_enqueue_op_inputs.cc | 165 ++++++++++++++++++ 4 files changed, 250 insertions(+) create mode 100644 tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir create mode 100644 tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc diff --git a/tensorflow/compiler/mlir/tensorflow/BUILD b/tensorflow/compiler/mlir/tensorflow/BUILD index 907bbb6734e..c4d1eb04edf 100644 --- a/tensorflow/compiler/mlir/tensorflow/BUILD +++ b/tensorflow/compiler/mlir/tensorflow/BUILD @@ -624,6 +624,7 @@ cc_library( "transforms/tpu_rewrite_pass.cc", "transforms/tpu_sharding_identification_pass.cc", "transforms/tpu_space_to_depth_pass.cc", + "transforms/tpu_update_embedding_enqueue_op_inputs.cc", "transforms/tpu_variable_runtime_reformatting.cc", "translate/breakup-islands.cc", "translate/tf_executor_to_functional.cc", diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir new file mode 100644 index 00000000000..5a26f5d905e --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir @@ -0,0 +1,79 @@ +// RUN: tf-opt %s -split-input-file -verify-diagnostics -tf-tpu-update-embedding-enqueue-op-inputs | FileCheck %s + +// CHECK-LABEL: func @check_enqueue_ops_update_for_eval +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_3:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor +func @check_enqueue_ops_update_for_eval(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + // CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"() + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + + // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_7]]) + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} + +// ----- + +// CHECK-LABEL: func @check_enqueue_ops_update_for_training +// CHECK-SAME: %[[ARG_0:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_1:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_2:[a-z0-9]*]]: tensor +// CHECK-SAME: 
%[[ARG_3:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_4:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_5:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_6:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_7:[a-z0-9]*]]: tensor +// CHECK-SAME: %[[ARG_8:[a-z0-9]*]]: tensor +func @check_enqueue_ops_update_for_training(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + // CHECK: %[[CONST_0:[a-z0-9]*]] = "tf.Const"() + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + + %2 = "tf.Const"() {value = dense<0.0> : tensor<2x2xf32>} : () -> tensor<2x2xf32> + %3 = "tf.Const"() {value = dense<0.0> : tensor<4x4xf32>} : () -> tensor<4x4xf32> + "tf.SendTPUEmbeddingGradients"(%2, %3) {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D", operand_segment_sizes = dense<[2, 0]> : vector<2xi32>} : (tensor<2x2xf32>, tensor<4x4xf32>) -> () + + // CHECK: "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%[[ARG_0]], %[[ARG_1]], %[[ARG_2]], %[[ARG_3]], %[[ARG_4]], %[[ARG_5]], %[[CONST_0]], %[[CONST_0]], %[[CONST_0]], %[[ARG_6]]) + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call1", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + %4:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} + +// ----- + +func @check_enqueue_ops_with_different_attr_disallowed(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + // expected-error @+1 {{TPU embedding enqueue op must have corresponding RecvTPUEmbeddingActivations op}} + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} + +// ----- + +func @check_embedding_ops_with_missing_attribute_disallowed(%arg0: tensor, %arg1: tensor, + %arg2 :tensor, %arg3: tensor, %arg4: tensor, %arg5: tensor, + %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { + %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> + %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor + "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () + 
// expected-error @+1 {{missing required attribute: `_tpu_embedding_layer`}} + %2:2 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) + return +} diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h index 85efb761d8b..5af8a0195a4 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/passes.h +++ b/tensorflow/compiler/mlir/tensorflow/transforms/passes.h @@ -287,6 +287,11 @@ CreateTPUExtractHeadTailOutsideCompilationPass(); // that are only used for host computation. std::unique_ptr> CreateTPUHostComputationExpansionPass(); +// Creates a pass that updates inputs to TPU embedding layer enqueue ops so that +// correct ops are invoked during training and evaluation. +std::unique_ptr> +CreateTPUUpdateEmbeddingEnqueueOpInputsPass(); + // Creates a pass that extract outside compilation (CPU ops inside TPU cluster) // ops to a separate parallel_execute region to run on CPU. std::unique_ptr> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc new file mode 100644 index 00000000000..9638efcad86 --- /dev/null +++ b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc @@ -0,0 +1,165 @@ +/* Copyright 2020 The TensorFlow Authors. All Rights Reserved. + +Licensed under the Apache License, Version 2.0 (the "License"); +you may not use this file except in compliance with the License. +You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + +Unless required by applicable law or agreed to in writing, software +distributed under the License is distributed on an "AS IS" BASIS, +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +See the License for the specific language governing permissions and +limitations under the License. +==============================================================================*/ + +#include "llvm/Support/Casting.h" +#include "mlir/IR/Attributes.h" // from @llvm-project +#include "mlir/IR/Block.h" // from @llvm-project +#include "mlir/IR/Builders.h" // from @llvm-project +#include "mlir/IR/Operation.h" // from @llvm-project +#include "mlir/IR/Value.h" // from @llvm-project +#include "mlir/IR/Visitors.h" // from @llvm-project +#include "mlir/Pass/Pass.h" // from @llvm-project +#include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" +#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" + +namespace mlir { +namespace TFTPU { +namespace { + +constexpr char kTPUEmbeddingAttr[] = "_tpu_embedding_layer"; + +struct TPUUpdateEmbeddingEnqueueOpInputs + : public PassWrapper { + void runOnFunction() override; +}; + +// Extracts `_tpu_embedding_layer` attribute from TPU embedding ops and +// clear the attribute from the operation. This ensures that future optimization +// passes does not trigger additional logic due to presence of this attribute. +LogicalResult ExtractEmbeddingAttribute( + Operation* op, std::map* embedding_op_map) { + auto embedding_attr = op->getAttrOfType(kTPUEmbeddingAttr); + if (!embedding_attr) + return op->emitOpError( + "missing required attribute: `_tpu_embedding_layer`"); + + auto it = embedding_op_map->emplace(embedding_attr.getValue().str(), op); + if (!it.second) + return op->emitOpError( + "found duplicate tpu embedding ops. 
This usually happens when " + "there are multiple TPUEmbedding layers."); + + op->removeAttr(kTPUEmbeddingAttr); + return success(); +} + +LogicalResult FindTPUEmbeddingOps( + FuncOp func_op, std::map* enqueue_op_map, + std::map* recv_activation_op_map, + std::map* send_gradient_op_map) { + auto walk_result = func_op.walk([&](Operation* op) { + if (llvm::isa(op)) { + if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map))) + return WalkResult::interrupt(); + } + + if (llvm::isa(op)) { + if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map))) + return WalkResult::interrupt(); + } + + if (llvm::isa(op) || + llvm::isa(op)) { + if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map))) + return WalkResult::interrupt(); + } + return WalkResult::advance(); + }); + if (walk_result.wasInterrupted()) return failure(); + + return success(); +} + +// Updates the operand of TPU embedding enqueue ops depending on whether +// the graph is in training mode or in non-training mode. +// If SendTPUEmbeddingGradients op is present, this means that graph is in +// training mode. As so, correctly feed in `then` branch value of SelectV2 +// operand as inputs to the TPU embedding enqueue ops. +LogicalResult UpdateEmbeddingEnqueueOpInput( + const std::map& enqueue_op_map, + const std::map& recv_activation_op_map, + const std::map& send_gradient_op_map) { + for (const auto& it : enqueue_op_map) { + const auto embedding_attr = it.first; + Operation* embedding_op = it.second; + if (!recv_activation_op_map.count(embedding_attr)) + return embedding_op->emitOpError( + "TPU embedding enqueue op must have corresponding " + "RecvTPUEmbeddingActivations op"); + + // TPU Embedding enqueue ops take different inputs depending on whether + // graph is in training mode or in eval/prediction mode. The inputs to the + // enqueue ops are present/listed as operands to SelectV2 op. Then branch + // operand of the SelectV2 op represents input to take during training + // and else branch operand represents input to take during + // prediction/evaluation. If SendTPUEmbeddingGradients op exists in the + // graph, then graph is in training mode, so correctly forward the input + // of SelectV2 op as operand to the TPU embedding enqueue op. + bool is_training = send_gradient_op_map.count(embedding_attr); + for (auto enqueue_operand : embedding_op->getOperands()) { + if (auto select = llvm::dyn_cast_or_null( + enqueue_operand.getDefiningOp())) { + enqueue_operand.replaceAllUsesWith(is_training ? select.t() + : select.e()); + } + } + } + + return success(); +} + +void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() { + OpBuilder builder(&getContext()); + auto func_op = getFunction(); + + // All TPU embedding layer related ops are annotated with + // `_tpu_embedding_layer` attribute along with corresponding string attribute. + // Store all tpu embedding layer related ops with value of + // `_tpu_embedding_layer` attribute as map key. 
+ std::map enqueue_op_map; + std::map recv_activation_op_map; + std::map send_gradient_op_map; + if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map, + &recv_activation_op_map, + &send_gradient_op_map))) + return signalPassFailure(); + + if (enqueue_op_map.size() != recv_activation_op_map.size()) { + func_op.emitError( + "Number of embedding enqueue ops must match the number " + "of RecvTPUEmbeddingActivation op"); + return signalPassFailure(); + } + + if (failed(UpdateEmbeddingEnqueueOpInput( + enqueue_op_map, recv_activation_op_map, send_gradient_op_map))) + return signalPassFailure(); +} + +} // anonymous namespace + +std::unique_ptr> +CreateTPUUpdateEmbeddingEnqueueOpInputsPass() { + return std::make_unique(); +} + +static PassRegistration pass( + "tf-tpu-update-embedding-enqueue-op-inputs", + "Updates inputs to TPU embedding enqueue ops depending on whether graph " + "is in training mode or in evaluation mode."); + +} // namespace TFTPU +} // namespace mlir From 41b7f9c61b7b28cecfe0e39a7ea09e84caa9dfc9 Mon Sep 17 00:00:00 2001 From: Francois Chollet Date: Wed, 8 Jul 2020 13:47:06 -0700 Subject: [PATCH 78/88] Fix critical bug with `add_loss` TFOpLayer graph construction that caused incorrect loss values and backprop issues PiperOrigin-RevId: 320257330 Change-Id: I0a030bc7632735b152454657fd15e41539b4e4bd --- tensorflow/python/keras/backend.py | 6 ++-- .../python/keras/engine/functional_test.py | 32 +++++++++++++++++++ 2 files changed, 35 insertions(+), 3 deletions(-) diff --git a/tensorflow/python/keras/backend.py b/tensorflow/python/keras/backend.py index d5f6db7d7e8..1dc11052edb 100644 --- a/tensorflow/python/keras/backend.py +++ b/tensorflow/python/keras/backend.py @@ -4693,7 +4693,7 @@ def categorical_crossentropy(target, output, from_logits=False, axis=-1): labels=target, logits=output, axis=axis) if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and - output.op.type == 'Softmax'): + output.op.type == 'Softmax') and not hasattr(output, '_keras_history'): # When softmax activation function is used for output operation, we # use logits from the softmax function directly to compute loss in order # to prevent collapsing zero when training. @@ -4738,7 +4738,7 @@ def sparse_categorical_crossentropy(target, output, from_logits=False, axis=-1): if (not from_logits and not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and - output.op.type == 'Softmax'): + output.op.type == 'Softmax') and not hasattr(output, '_keras_history'): # When softmax activation function is used for output operation, we # use logits from the softmax function directly to compute loss in order # to prevent collapsing zero when training. @@ -4817,7 +4817,7 @@ def binary_crossentropy(target, output, from_logits=False): return nn.sigmoid_cross_entropy_with_logits(labels=target, logits=output) if (not isinstance(output, (ops.EagerTensor, variables_module.Variable)) and - output.op.type == 'Sigmoid'): + output.op.type == 'Sigmoid') and not hasattr(output, '_keras_history'): # When sigmoid activation function is used for output operation, we # use logits from the sigmoid function directly to compute loss in order # to prevent collapsing zero when training. 
diff --git a/tensorflow/python/keras/engine/functional_test.py b/tensorflow/python/keras/engine/functional_test.py index db7b3d696ab..47e4dc488a3 100644 --- a/tensorflow/python/keras/engine/functional_test.py +++ b/tensorflow/python/keras/engine/functional_test.py @@ -34,6 +34,7 @@ from tensorflow.python.keras import combinations from tensorflow.python.keras import initializers from tensorflow.python.keras import keras_parameterized from tensorflow.python.keras import layers +from tensorflow.python.keras import losses from tensorflow.python.keras import models from tensorflow.python.keras import testing_utils from tensorflow.python.keras.engine import base_layer @@ -1904,6 +1905,37 @@ class AddLossTest(keras_parameterized.TestCase): self.assertAllClose(model.get_weights(), model2.get_weights()) + def test_add_loss_crossentropy_backtracking(self): + inputs = input_layer_lib.Input((2,)) + labels = input_layer_lib.Input((1,)) + outputs = layers.Dense(1, activation='sigmoid')(inputs) + model = functional.Functional([inputs, labels], outputs) + model.add_loss(losses.binary_crossentropy(labels, outputs)) + model.compile('adam') + x = np.random.random((2, 2)) + y = np.random.random((2, 1)) + model.fit([x, y]) + + inputs = input_layer_lib.Input((2,)) + labels = input_layer_lib.Input((2,)) + outputs = layers.Dense(2, activation='softmax')(inputs) + model = functional.Functional([inputs, labels], outputs) + model.add_loss(losses.categorical_crossentropy(labels, outputs)) + model.compile('adam') + x = np.random.random((2, 2)) + y = np.random.random((2, 2)) + model.fit([x, y]) + + inputs = input_layer_lib.Input((2,)) + labels = input_layer_lib.Input((1,), dtype='int32') + outputs = layers.Dense(2, activation='softmax')(inputs) + model = functional.Functional([inputs, labels], outputs) + model.add_loss(losses.sparse_categorical_crossentropy(labels, outputs)) + model.compile('adam') + x = np.random.random((2, 2)) + y = np.random.randint(0, 2, size=(2, 1)) + model.fit([x, y]) + @combinations.generate(combinations.keras_mode_combinations()) class WeightAccessTest(keras_parameterized.TestCase): From 1404718bfaf2bcf21ecb30cc89d67158d241506f Mon Sep 17 00:00:00 2001 From: Eugene Zhulenev Date: Wed, 8 Jul 2020 13:55:14 -0700 Subject: [PATCH 79/88] [MLIR:TF] Canonicalize redundant reshape PiperOrigin-RevId: 320258973 Change-Id: I9064c42b70eddd0363cf6966344c39384561e7be --- .../mlir/tensorflow/ir/tf_generated_ops.td | 2 ++ tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc | 5 +++++ .../mlir/tensorflow/tests/canonicalize.mlir | 14 ++++++++++++++ .../mlir/tensorflow/transforms/canonicalize.td | 7 +++++++ 4 files changed, 28 insertions(+) diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td index 1edac0e535f..9612caa8782 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_generated_ops.td @@ -7270,6 +7270,8 @@ reshape(t, []) ==> 7 let verifier = [{ return Verify(*this); }]; + + let hasCanonicalizer = 1; } def TF_ResizeBilinearOp : TF_Op<"ResizeBilinear", [NoSideEffect]> { diff --git a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc index a26de6b9c83..e16318f8e11 100644 --- a/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc +++ b/tensorflow/compiler/mlir/tensorflow/ir/tf_ops.cc @@ -2889,6 +2889,11 @@ void ReshapeOp::build(OpBuilder &builder, OperationState &result, Value tensor, return unranked(); } +void 
ReshapeOp::getCanonicalizationPatterns(OwningRewritePatternList &results, + MLIRContext *context) { + results.insert(context); +} + //===----------------------------------------------------------------------===// // SelectOp //===----------------------------------------------------------------------===// diff --git a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir index 02a006130ad..8597740a4ae 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/canonicalize.mlir @@ -363,6 +363,20 @@ func @testDoubleReciprocal(%arg0: tensor<8x16x32x64xi32>) -> tensor<8x16x32x64xi // CHECK: return %arg0 } +// CHECK-LABEL: testRedundantReshape +func @testRedundantReshape(%arg0: tensor<4x4xi32>) -> tensor<2x8xi32> { + %0 = "tf.Const"() {value = dense<[8, 2]> : tensor<2xi32>} : () -> tensor<2xi32> + %1 = "tf.Const"() {value = dense<[2, 8]> : tensor<2xi32>} : () -> tensor<2xi32> + %2 = "tf.Reshape"(%arg0, %0) : (tensor<4x4xi32>, tensor<2xi32>) -> tensor<8x2xi32> + %3 = "tf.Reshape"(%2, %1) : (tensor<8x2xi32>, tensor<2xi32>) -> tensor<2x8xi32> + return %3: tensor<2x8xi32> + + // CHECK: %0 = "tf.Const" + // CHECK-SAME: value = dense<[2, 8]> : tensor<2xi32> + // CHECK: %1 = "tf.Reshape"(%arg0, %0) + // CHECK: return %1 : tensor<2x8xi32> +} + // CHECK-LABEL: testSelectScalarPred func @testSelectScalarPred(%arg0: tensor, %arg1: tensor<4x2xf16>, %arg2: tensor<4x2xf16>) -> tensor<4x2xf16> { // CHECK-NEXT: "tf.SelectV2"(%arg0, %arg1, %arg2) : (tensor, tensor<4x2xf16>, tensor<4x2xf16>) -> tensor<4x2xf16> diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td index e6a9ce4ad62..9d72284da91 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td +++ b/tensorflow/compiler/mlir/tensorflow/transforms/canonicalize.td @@ -194,6 +194,13 @@ def RealDivWithSqrtDivisor : Pat<(TF_RealDivOp $arg0, (TF_SqrtOp $arg1)), def ReciprocalNested : Pat<(TF_ReciprocalOp (TF_ReciprocalOp $arg)), (replaceWithValue $arg)>; +//===----------------------------------------------------------------------===// +// Reshape op patterns. +//===----------------------------------------------------------------------===// + +def RedundantReshape : Pat<(TF_ReshapeOp (TF_ReshapeOp $arg, $unused), $shape), + (TF_ReshapeOp $arg, $shape)>; + //===----------------------------------------------------------------------===// // Select op patterns. //===----------------------------------------------------------------------===// From d59724604c8761f2472ae9e9013550828a0b04de Mon Sep 17 00:00:00 2001 From: Mangpo Phothilimthana Date: Wed, 8 Jul 2020 14:00:26 -0700 Subject: [PATCH 80/88] HloModule::Clone function copies input_output_alias_config. Without this copy, the cloned module is not identical to the original module because it does not have the input_output_alias_config. 
PiperOrigin-RevId: 320260067 Change-Id: I2c1b8cdda50bc253fd30d844c3b1840c0bcf596c --- tensorflow/compiler/xla/service/hlo_module.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/tensorflow/compiler/xla/service/hlo_module.cc b/tensorflow/compiler/xla/service/hlo_module.cc index c0b733c19cd..c715d016c4f 100644 --- a/tensorflow/compiler/xla/service/hlo_module.cc +++ b/tensorflow/compiler/xla/service/hlo_module.cc @@ -690,6 +690,7 @@ std::unique_ptr HloModule::Clone(const HloModuleConfig& config, HloCloneContext context(module.get(), suffix); auto cloned_computation = entry_computation_->Clone(suffix, &context); module->AddEntryComputation(std::move(cloned_computation)); + module->input_output_alias_config() = input_output_alias_config(); if (has_schedule() && schedule().Verify().ok()) { HloSchedule clone_schedule(module.get()); From 93990d3a3cdaf17ec730fcbb798ecd479d0bbfaf Mon Sep 17 00:00:00 2001 From: Ran Chen Date: Wed, 8 Jul 2020 14:34:05 -0700 Subject: [PATCH 81/88] Update the docstring of reduce_to and batch_reduce_to PiperOrigin-RevId: 320267089 Change-Id: I09821547a91cbc7441ac6aa0e188a614a5055ddb --- .../python/distribute/distribute_lib.py | 129 +++++++++++++++--- 1 file changed, 113 insertions(+), 16 deletions(-) diff --git a/tensorflow/python/distribute/distribute_lib.py b/tensorflow/python/distribute/distribute_lib.py index e82e60eddb6..d398a850a41 100644 --- a/tensorflow/python/distribute/distribute_lib.py +++ b/tensorflow/python/distribute/distribute_lib.py @@ -2239,20 +2239,72 @@ class StrategyExtendedV2(object): def reduce_to(self, reduce_op, value, destinations, experimental_hints=None): """Combine (via e.g. sum or mean) values across replicas. + `reduce_to` aggregates `tf.distribute.DistributedValues` and distributed + variables. It supports both dense values and `tf.IndexedSlices`. + + This API currently can only be called in cross-replica context. Other + variants to reduce values across replicas are: + * `tf.distribute.StrategyExtended.batch_reduce_to`: the batch version of + this API. + * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API + in replica context. It supports both batched and non-batched all-reduce. + * `tf.distribute.Strategy.reduce`: a more convenient method to reduce + to the host in cross-replica context. + + `destinations` specifies where to reduce the value to, e.g. "GPU:0". You can + also pass in a `Tensor`, and the destinations will be the device of that + tensor. For all-reduce, pass the same to `value` and `destinations`. + + It can be used in `tf.distribute.ReplicaContext.merge_call` to write code + that works for all `tf.distribute.Strategy`. + + >>> @tf.function + ... def step_fn(var): + ... + ... def merge_fn(strategy, value, var): + ... # All-reduce the value. Note that `value` here is a + ... # `tf.distribute.DistributedValues`. + ... reduced = strategy.extended.reduce_to(tf.distribute.ReduceOp.SUM, + ... value, destinations=var) + ... strategy.extended.update(var, lambda var, value: var.assign(value), + ... args=(reduced,)) + ... + ... value = tf.identity(1.) + ... tf.distribute.get_replica_context().merge_call(merge_fn, + ... args=(value, var)) + >>> + >>> def run(strategy): + ... with strategy.scope(): + ... v = tf.Variable(0.) + ... strategy.run(step_fn, args=(v,)) + ... return v + >>> + >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) + MirroredVariable:{ + 0: , + 1: + } + >>> run(tf.distribute.experimental.CentralStorageStrategy( + ... 
compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) + + >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) + + Args: - reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. - value: A per-replica value with one value per replica. - destinations: A mirrored variable, a per-replica tensor, or a device - string. The return value will be copied to all destination devices (or - all the devices where the `destinations` value resides). To perform an - all-reduction, pass `value` to `destinations`. - experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints - to perform collective operations. + reduce_op: a `tf.distribute.ReduceOp` or string. How to reduce the value. + value: a `tf.distribute.DistributedValue`, or a `tf.Tensor` like object. + destinations: a `tf.distribute.DistributedValue`, a `tf.Variable`, a + `tf.Tensor` alike object, or a device string. It specifies the devices + to reduce to. To perform an all-reduce, pass the same to `value` and + `destinations`. Note that if it's a `tf.Variable`, the value is reduced + to the devices of that variable, this method doesn't update the variable. + experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints + to perform collective operations. See + `tf.distrbute.experimental.CollectiveHints` for details. Returns: - A tensor or value mirrored to `destinations`. + A tensor or value reduced to `destinations`. """ - # TODO(josh11b): More docstring _require_cross_replica_or_default_context_extended(self) assert not isinstance(destinations, (list, tuple)) assert not isinstance(reduce_op, variable_scope.VariableAggregation) @@ -2273,17 +2325,62 @@ class StrategyExtendedV2(object): experimental_hints=None): """Combine multiple `reduce_to` calls into one for faster execution. + Similar to `reduce_to`, but accepts a list of (value, destinations) pairs. + It's more efficient than reduce each value separately. + + This API currently can only be called in cross-replica context. Other + variants to reduce values across replicas are: + * `tf.distribute.StrategyExtended.reduce_to`: the non-batch version of + this API. + * `tf.distribute.ReplicaContext.all_reduce`: the counterpart of this API + in replica context. It supports both batched and non-batched all-reduce. + * `tf.distribute.Strategy.reduce`: a more convenient method to reduce + to the host in cross-replica context. + + See `reduce_to` for more information. + + >>> @tf.function + ... def step_fn(var): + ... + ... def merge_fn(strategy, value, var): + ... # All-reduce the value. Note that `value` here is a + ... # `tf.distribute.DistributedValues`. + ... reduced = strategy.extended.batch_reduce_to( + ... tf.distribute.ReduceOp.SUM, [(value, var)])[0] + ... strategy.extended.update(var, lambda var, value: var.assign(value), + ... args=(reduced,)) + ... + ... value = tf.identity(1.) + ... tf.distribute.get_replica_context().merge_call(merge_fn, + ... args=(value, var)) + >>> + >>> def run(strategy): + ... with strategy.scope(): + ... v = tf.Variable(0.) + ... strategy.run(step_fn, args=(v,)) + ... return v + >>> + >>> run(tf.distribute.MirroredStrategy(["GPU:0", "GPU:1"])) + MirroredVariable:{ + 0: , + 1: + } + >>> run(tf.distribute.experimental.CentralStorageStrategy( + ... compute_devices=["GPU:0", "GPU:1"], parameter_device="CPU:0")) + + >>> run(tf.distribute.OneDeviceStrategy("GPU:0")) + + Args: - reduce_op: Reduction type, an instance of `tf.distribute.ReduceOp` enum. - value_destination_pairs: A sequence of (value, destinations) pairs. 
See - `reduce_to()` for a description. - experimental_hints: A `tf.distrbute.experimental.CollectiveHints`. Hints + reduce_op: a `tf.distribute.ReduceOp`. How to reduce the value. + value_destination_pairs: a sequence of (value, destinations) pairs. See + `reduce_to()` for descriptions. + experimental_hints: a `tf.distrbute.experimental.CollectiveHints`. Hints to perform collective operations. Returns: - A list of mirrored values, one per pair in `value_destination_pairs`. + A list of reduced values, one per pair in `value_destination_pairs`. """ - # TODO(josh11b): More docstring _require_cross_replica_or_default_context_extended(self) assert not isinstance(reduce_op, variable_scope.VariableAggregation) if isinstance(reduce_op, six.string_types): From 84d017ca0ec55d2eaf577799b9f360071f19b13d Mon Sep 17 00:00:00 2001 From: Lucy Fox Date: Wed, 8 Jul 2020 14:44:35 -0700 Subject: [PATCH 82/88] Update Grappler Remapper documentation to include DepthwiseConv2dNative fusion. PiperOrigin-RevId: 320269324 Change-Id: I9defa2bfe2deff9d709eca0e308f56637d896a2e --- tensorflow/core/grappler/optimizers/remapper.cc | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/tensorflow/core/grappler/optimizers/remapper.cc b/tensorflow/core/grappler/optimizers/remapper.cc index 44e6174970e..46c7afbc53a 100644 --- a/tensorflow/core/grappler/optimizers/remapper.cc +++ b/tensorflow/core/grappler/optimizers/remapper.cc @@ -47,10 +47,15 @@ namespace grappler { // MatMul + ... -> _FusedMatMul: // (1) MatMul + BiasAdd + // +// DepthwiseConv2dNative + ... -> _FusedDepthwiseConv2dNative: +// (1) DepthwiseConv2dNative + BiasAdd + +// // FusedBatchNorm[$is_training] + ... -> _FusedBatchNormEx[$is_training] // (1) FusedBatchNorm + // (2) FusedBatchNorm + SideInput + // +// In all cases, the supported activation functions are Relu, Relu6, and Elu. +// // Both Conv2D and MatMul implemented as Tensor contraction (on CPU), so all the // patterns are "ContractionWith...". 
namespace { From c38762b4c3ab3abdd6ebc1fd1550a9e5cb190595 Mon Sep 17 00:00:00 2001 From: Jared Duke Date: Wed, 8 Jul 2020 14:47:01 -0700 Subject: [PATCH 83/88] Reland "Fix issues with 32-bit ARM builds" PiperOrigin-RevId: 320269794 Change-Id: I9b7eb7561ad6c1d81d398745520e279b99c8dafe --- tensorflow/BUILD | 30 +++++++++++++++++++++++++++++ tensorflow/core/kernels/BUILD | 14 ++++++-------- tensorflow/core/platform/platform.h | 14 +++++++++----- 3 files changed, 45 insertions(+), 13 deletions(-) diff --git a/tensorflow/BUILD b/tensorflow/BUILD index d00608ccc98..0bd3724301f 100644 --- a/tensorflow/BUILD +++ b/tensorflow/BUILD @@ -260,6 +260,36 @@ config_setting( visibility = ["//visibility:public"], ) +config_setting( + name = "armeabi", + values = {"cpu": "armeabi"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "armeabi-v7a", + values = {"cpu": "armeabi-v7a"}, + visibility = ["//visibility:public"], +) + +config_setting( + name = "arm64-v8a", + values = {"cpu": "arm64-v8a"}, + visibility = ["//visibility:public"], +) + +selects.config_setting_group( + name = "arm_any", + match_any = [ + ":arm", + ":armeabi", + ":armeabi-v7a", + ":arm64-v8a", + ":linux_aarch64", + ":linux_armhf", + ], +) + config_setting( name = "freebsd", values = {"cpu": "freebsd"}, diff --git a/tensorflow/core/kernels/BUILD b/tensorflow/core/kernels/BUILD index 9ace481a991..0dee2a48e76 100644 --- a/tensorflow/core/kernels/BUILD +++ b/tensorflow/core/kernels/BUILD @@ -817,10 +817,9 @@ cc_library( srcs = ["eigen_contraction_kernel.cc"], hdrs = ["eigen_contraction_kernel.h"], defines = select({ - "//tensorflow:android": [], - "//tensorflow:arm": [], + "//tensorflow:android_x86": [], + "//tensorflow:arm_any": [], "//tensorflow:ios": [], - "//tensorflow:linux_aarch64": [], "//tensorflow:linux_ppc64le": [], "//conditions:default": [ "TENSORFLOW_USE_CUSTOM_CONTRACTION_KERNEL", @@ -832,10 +831,9 @@ cc_library( "//third_party/eigen3", "//tensorflow/core/platform:dynamic_annotations", ] + select({ - "//tensorflow:android": [], - "//tensorflow:arm": [], + "//tensorflow:android_x86": [], + "//tensorflow:arm_any": [], "//tensorflow:ios": [], - "//tensorflow:linux_aarch64": [], "//tensorflow:linux_ppc64le": [], "//conditions:default": ["@mkl_dnn//:mkldnn_single_threaded"], }), @@ -3224,8 +3222,8 @@ tf_cc_test( name = "eigen_mkldnn_contraction_kernel_test", size = "small", srcs = select({ - "//tensorflow:android": [], - "//tensorflow:arm": [], + "//tensorflow:android_x86": [], + "//tensorflow:arm_any": [], "//tensorflow:ios": [], "//tensorflow:linux_ppc64le": [], ":no_mkldnn_contraction_kernel": [], diff --git a/tensorflow/core/platform/platform.h b/tensorflow/core/platform/platform.h index a840d7b06e3..3375a6e50eb 100644 --- a/tensorflow/core/platform/platform.h +++ b/tensorflow/core/platform/platform.h @@ -41,18 +41,22 @@ limitations under the License. #elif defined(_WIN32) #define PLATFORM_WINDOWS -#elif defined(__arm__) -#define PLATFORM_POSIX - #elif defined(__EMSCRIPTEN__) #define PLATFORM_PORTABLE_GOOGLE #define PLATFORM_POSIX +// EMSCRIPTEN builds are considered "mobile" for the sake of portability. +#define IS_MOBILE_PLATFORM + +#elif defined(__arm__) || defined(__aarch64__) +// If no platform specified, use: +#define PLATFORM_POSIX // Require an outside macro to tell us if we're building for Raspberry Pi or // another ARM device that's not a mobile platform. 
-#if !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE) +#if !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE) && \ + !defined(PLATFORM_GOOGLE) #define IS_MOBILE_PLATFORM -#endif // !defined(RASPBERRY_PI) && !defined(ARM_NON_MOBILE) +#endif #else // If no platform specified, use: From e1fcc870965fdd49a05e5ad6b2f6520494898585 Mon Sep 17 00:00:00 2001 From: Andy Ly Date: Wed, 8 Jul 2020 14:47:56 -0700 Subject: [PATCH 84/88] Update TPUUpdateEmbeddingEnqueueOpInputs pass to use llvm::StringMap instead of std::map and fix error message style. (NFC) PiperOrigin-RevId: 320269989 Change-Id: I0358e1345ccd2bc2f93adde5d793b62960f7f205 --- ...pu_update_embedding_enqueue_op_inputs.mlir | 4 +- .../tpu_update_embedding_enqueue_op_inputs.cc | 63 +++++++++---------- 2 files changed, 32 insertions(+), 35 deletions(-) diff --git a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir index 5a26f5d905e..b77e4b1fbd0 100644 --- a/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir +++ b/tensorflow/compiler/mlir/tensorflow/tests/tpu_update_embedding_enqueue_op_inputs.mlir @@ -59,7 +59,7 @@ func @check_enqueue_ops_with_different_attr_disallowed(%arg0: tensor, % %arg6: tensor, %arg7: tensor, %arg8: tensor) -> () { %0 = "tf.Const"() {value = dense<[]> : tensor<0xf32>} : () -> tensor<0xf32> %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor - // expected-error @+1 {{TPU embedding enqueue op must have corresponding RecvTPUEmbeddingActivations op}} + // expected-error @+1 {{'tf.EnqueueTPUEmbeddingSparseTensorBatch' op must have a corresponding 'tf.RecvTPUEmbeddingActivations' op}} "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () %2:2 = "tf.RecvTPUEmbeddingActivations"() {_tpu_embedding_layer = "call1", config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) return @@ -73,7 +73,7 @@ func @check_embedding_ops_with_missing_attribute_disallowed(%arg0: tensor : tensor<0xf32>} : () -> tensor<0xf32> %1 = "tf.SelectV2"(%arg8, %arg6, %arg7) : (tensor, tensor, tensor) -> tensor "tf.EnqueueTPUEmbeddingSparseTensorBatch"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %0, %0, %0, %1) {_tpu_embedding_layer = "call_123", _xla_outside_compilation = "0", combiners = ["mean", "sum"], device_ordinal = -1 : i64, max_sequence_lengths = [0, 0, 0], table_ids = [1, 1, 0]} : (tensor, tensor, tensor, tensor, tensor, tensor, tensor<0xf32>, tensor<0xf32>, tensor<0xf32>, tensor) -> () - // expected-error @+1 {{missing required attribute: `_tpu_embedding_layer`}} + // expected-error @+1 {{'tf.RecvTPUEmbeddingActivations' op requires attribute '_tpu_embedding_layer'}} %2:2 = "tf.RecvTPUEmbeddingActivations"() {config = "\0A\0B\0C\0D"} : () -> (tensor<2x2xf32>, tensor<4x4xf32>) return } diff --git a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc index 9638efcad86..f3588c8359b 100644 --- a/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc +++ 
b/tensorflow/compiler/mlir/tensorflow/transforms/tpu_update_embedding_enqueue_op_inputs.cc @@ -13,6 +13,7 @@ See the License for the specific language governing permissions and limitations under the License. ==============================================================================*/ +#include "llvm/ADT/StringMap.h" #include "llvm/Support/Casting.h" #include "mlir/IR/Attributes.h" // from @llvm-project #include "mlir/IR/Block.h" // from @llvm-project @@ -22,6 +23,7 @@ limitations under the License. #include "mlir/IR/Visitors.h" // from @llvm-project #include "mlir/Pass/Pass.h" // from @llvm-project #include "mlir/Pass/PassRegistry.h" // from @llvm-project +#include "mlir/Support/LogicalResult.h" // from @llvm-project #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" @@ -40,47 +42,41 @@ struct TPUUpdateEmbeddingEnqueueOpInputs // clear the attribute from the operation. This ensures that future optimization // passes does not trigger additional logic due to presence of this attribute. LogicalResult ExtractEmbeddingAttribute( - Operation* op, std::map* embedding_op_map) { + Operation* op, llvm::StringMap* embedding_op_map) { auto embedding_attr = op->getAttrOfType(kTPUEmbeddingAttr); if (!embedding_attr) - return op->emitOpError( - "missing required attribute: `_tpu_embedding_layer`"); + return op->emitOpError("requires attribute '_tpu_embedding_layer'"); - auto it = embedding_op_map->emplace(embedding_attr.getValue().str(), op); - if (!it.second) + if (!embedding_op_map->insert({embedding_attr.getValue(), op}).second) return op->emitOpError( - "found duplicate tpu embedding ops. This usually happens when " - "there are multiple TPUEmbedding layers."); + "found duplicate TPU embedding ops potentially from multiple " + "TPUEmbedding layers"); op->removeAttr(kTPUEmbeddingAttr); return success(); } LogicalResult FindTPUEmbeddingOps( - FuncOp func_op, std::map* enqueue_op_map, - std::map* recv_activation_op_map, - std::map* send_gradient_op_map) { + FuncOp func_op, llvm::StringMap* enqueue_op_map, + llvm::StringMap* recv_activation_op_map, + llvm::StringMap* send_gradient_op_map) { auto walk_result = func_op.walk([&](Operation* op) { - if (llvm::isa(op)) { + if (llvm::isa(op)) if (failed(ExtractEmbeddingAttribute(op, recv_activation_op_map))) return WalkResult::interrupt(); - } - if (llvm::isa(op)) { + if (llvm::isa(op)) if (failed(ExtractEmbeddingAttribute(op, send_gradient_op_map))) return WalkResult::interrupt(); - } - if (llvm::isa(op) || - llvm::isa(op)) { + if (llvm::isa(op)) if (failed(ExtractEmbeddingAttribute(op, enqueue_op_map))) return WalkResult::interrupt(); - } + return WalkResult::advance(); }); - if (walk_result.wasInterrupted()) return failure(); - - return success(); + return failure(walk_result.wasInterrupted()); } // Updates the operand of TPU embedding enqueue ops depending on whether @@ -89,16 +85,16 @@ LogicalResult FindTPUEmbeddingOps( // training mode. As so, correctly feed in `then` branch value of SelectV2 // operand as inputs to the TPU embedding enqueue ops. 
LogicalResult UpdateEmbeddingEnqueueOpInput( - const std::map& enqueue_op_map, - const std::map& recv_activation_op_map, - const std::map& send_gradient_op_map) { + const llvm::StringMap& enqueue_op_map, + const llvm::StringMap& recv_activation_op_map, + const llvm::StringMap& send_gradient_op_map) { for (const auto& it : enqueue_op_map) { - const auto embedding_attr = it.first; + const auto& embedding_attr = it.getKey(); Operation* embedding_op = it.second; if (!recv_activation_op_map.count(embedding_attr)) - return embedding_op->emitOpError( - "TPU embedding enqueue op must have corresponding " - "RecvTPUEmbeddingActivations op"); + return embedding_op->emitOpError() + << "must have a corresponding '" + << TF::RecvTPUEmbeddingActivationsOp::getOperationName() << "' op"; // TPU Embedding enqueue ops take different inputs depending on whether // graph is in training mode or in eval/prediction mode. The inputs to the @@ -129,18 +125,19 @@ void TPUUpdateEmbeddingEnqueueOpInputs::runOnFunction() { // `_tpu_embedding_layer` attribute along with corresponding string attribute. // Store all tpu embedding layer related ops with value of // `_tpu_embedding_layer` attribute as map key. - std::map enqueue_op_map; - std::map recv_activation_op_map; - std::map send_gradient_op_map; + llvm::StringMap enqueue_op_map; + llvm::StringMap recv_activation_op_map; + llvm::StringMap send_gradient_op_map; if (failed(FindTPUEmbeddingOps(func_op, &enqueue_op_map, &recv_activation_op_map, &send_gradient_op_map))) return signalPassFailure(); if (enqueue_op_map.size() != recv_activation_op_map.size()) { - func_op.emitError( - "Number of embedding enqueue ops must match the number " - "of RecvTPUEmbeddingActivation op"); + func_op.emitError() << "expects the number of embedding enqueue ops to " + "match the number of '" + << TF::RecvTPUEmbeddingActivationsOp::getOperationName() + << "' ops"; return signalPassFailure(); } From 546df2593f7df8299eb6f72d7497c09d22436db3 Mon Sep 17 00:00:00 2001 From: Peng Wang Date: Wed, 8 Jul 2020 14:54:23 -0700 Subject: [PATCH 85/88] [TF-numpy] Re-enables signature check against numpy, and fixes problems found by the check. PiperOrigin-RevId: 320271282 Change-Id: Iea02c7dd8802b6b4fa7e879f3c2d73bdaf7a92e3 --- .../python/ops/numpy_ops/np_array_ops.py | 19 +++--- .../python/ops/numpy_ops/np_math_ops.py | 4 +- tensorflow/python/ops/numpy_ops/np_utils.py | 58 ++++++++++++------- .../python/ops/numpy_ops/np_utils_test.py | 49 ++++++++++++---- 4 files changed, 90 insertions(+), 40 deletions(-) diff --git a/tensorflow/python/ops/numpy_ops/np_array_ops.py b/tensorflow/python/ops/numpy_ops/np_array_ops.py index efc6020f070..0bf03a0bd93 100644 --- a/tensorflow/python/ops/numpy_ops/np_array_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_array_ops.py @@ -472,11 +472,11 @@ def cumsum(a, axis=None, dtype=None): # pylint: disable=missing-docstring @np_utils.np_doc('imag') -def imag(a): - a = asarray(a) - # TODO(srbs): np.imag returns a scalar if a is a scalar, whereas we always +def imag(val): + val = asarray(val) + # TODO(srbs): np.imag returns a scalar if `val` is a scalar, whereas we always # return an ndarray. 
- return np_utils.tensor_to_ndarray(math_ops.imag(a.data)) + return np_utils.tensor_to_ndarray(math_ops.imag(val.data)) _TO_INT_ = 0 @@ -874,16 +874,17 @@ setattr(np_arrays.ndarray, 'reshape', _reshape_method_wrapper) @np_utils.np_doc('pad') -def pad(ary, pad_width, mode, constant_values=0): +def pad(array, pad_width, mode, **kwargs): # pylint: disable=redefined-outer-name """Only supports modes 'constant', 'reflect' and 'symmetric' currently.""" + constant_values = kwargs.get('constant_values', 0) if not (mode == 'constant' or mode == 'reflect' or mode == 'symmetric'): raise ValueError('Unsupported padding mode: ' + mode) mode = mode.upper() - ary = asarray(ary) + array = asarray(array) pad_width = asarray(pad_width, dtype=dtypes.int32) return np_utils.tensor_to_ndarray( array_ops.pad( - tensor=ary.data, + tensor=array.data, paddings=pad_width.data, mode=mode, constant_values=constant_values)) @@ -959,8 +960,8 @@ def ndim(a): @np_utils.np_doc('isscalar') -def isscalar(a): - return ndim(a) == 0 +def isscalar(num): + return ndim(num) == 0 def _boundaries_to_sizes(a, boundaries, axis): diff --git a/tensorflow/python/ops/numpy_ops/np_math_ops.py b/tensorflow/python/ops/numpy_ops/np_math_ops.py index 2b6247cfd81..fd7ba5f94f7 100644 --- a/tensorflow/python/ops/numpy_ops/np_math_ops.py +++ b/tensorflow/python/ops/numpy_ops/np_math_ops.py @@ -1348,7 +1348,9 @@ def meshgrid(*xi, **kwargs): return outputs -@np_utils.np_doc('einsum') +# Uses np_doc_only here because np.einsum (in 1.16) doesn't have argument +# `subscripts`, even though the doc says it has. +@np_utils.np_doc_only('einsum') def einsum(subscripts, *operands, **kwargs): # pylint: disable=missing-docstring casting = kwargs.get('casting', 'safe') optimize = kwargs.get('optimize', False) diff --git a/tensorflow/python/ops/numpy_ops/np_utils.py b/tensorflow/python/ops/numpy_ops/np_utils.py index db0e6d9e760..fe290e2c5ef 100644 --- a/tensorflow/python/ops/numpy_ops/np_utils.py +++ b/tensorflow/python/ops/numpy_ops/np_utils.py @@ -20,6 +20,7 @@ from __future__ import division from __future__ import print_function import inspect +import os import numpy as np from tensorflow.python.framework import dtypes @@ -227,6 +228,19 @@ def _np_doc_helper(f, np_f, np_fun_name=None, unsupported_params=None): return doc +_is_sig_mismatch_an_error = ( + os.getenv('TF_NP_SIG_MISMATCH_IS_ERROR', 'False') in ('True', 'true', '1')) + + +def is_sig_mismatch_an_error(): + return _is_sig_mismatch_an_error + + +def set_is_sig_mismatch_an_error(value): + global _is_sig_mismatch_an_error + _is_sig_mismatch_an_error = value + + def np_doc(np_fun_name, np_fun=None, export=True): """Attachs numpy docstring to a function. @@ -254,26 +268,30 @@ def np_doc(np_fun_name, np_fun=None, export=True): sig = inspect.signature(f) except ValueError: sig = None - # TODO(wangpeng): Enable this. - # Looks like this may not work with different versions of numpy. 
- # if sig is not None: - # for name, param in sig.parameters.items(): - # np_param = np_sig.parameters.get(name) - # if np_param is None: - # raise TypeError('Cannot find parameter "%s" in the numpy - # function\'s ' 'signature' % name) - # if not _is_compatible_param_kind(param.kind, np_param.kind): - # raise TypeError( - # 'Parameter "%s" is of kind %s while in numpy it is of ' - # 'kind %s' % (name, param.kind, np_param.kind)) - # has_default = (param.default != inspect.Parameter.empty) - # np_has_default = (np_param.default != inspect.Parameter.empty) - # if has_default != np_has_default: - # raise TypeError('Parameter "%s" should%s have a default value' % - # (name, '' if np_has_default else ' not')) - # for name in np_sig.parameters: - # if name not in sig.parameters: - # unsupported_params.append(name) + if sig is not None: + for name, param in sig.parameters.items(): + np_param = np_sig.parameters.get(name) + if np_param is None: + if is_sig_mismatch_an_error(): + raise TypeError( + 'Cannot find parameter "%s" in the numpy function\'s ' + 'signature (which has these parameters: %s)' % + (name, list(np_sig.parameters.keys()))) + else: + continue + if (is_sig_mismatch_an_error() and + not _is_compatible_param_kind(param.kind, np_param.kind)): + raise TypeError( + 'Parameter "%s" is of kind %s while in numpy it is of ' + 'kind %s' % (name, param.kind, np_param.kind)) + has_default = (param.default != inspect.Parameter.empty) + np_has_default = (np_param.default != inspect.Parameter.empty) + if is_sig_mismatch_an_error() and has_default != np_has_default: + raise TypeError('Parameter "%s" should%s have a default value' % + (name, '' if np_has_default else ' not')) + for name in np_sig.parameters: + if name not in sig.parameters: + unsupported_params.append(name) f.__doc__ = _np_doc_helper( f, np_fun, diff --git a/tensorflow/python/ops/numpy_ops/np_utils_test.py b/tensorflow/python/ops/numpy_ops/np_utils_test.py index 62bd427abcf..8a217ab5b5d 100644 --- a/tensorflow/python/ops/numpy_ops/np_utils_test.py +++ b/tensorflow/python/ops/numpy_ops/np_utils_test.py @@ -38,6 +38,8 @@ class UtilsTest(test.TestCase): expected = """TensorFlow variant of `numpy.np_fun`. +Unsupported arguments: `x`. + f docstring. """ @@ -56,35 +58,62 @@ f docstring. 
""" self.assertEqual(expected, f.__doc__) - def testNpDocErrors(self): + # pylint: disable=unused-variable + def testSigMismatchIsError(self): + """Tests that signature mismatch is an error (when configured so).""" + if not np_utils._supports_signature(): + self.skipTest('inspect.signature not supported') - self.skipTest('Enable once np signature checking is done.') - # if not np_utils._supports_signature(): - # self.skipTest("inspect.signature not supported") + old_flag = np_utils.is_sig_mismatch_an_error() + np_utils.set_is_sig_mismatch_an_error(True) def np_fun(x, y=1, **kwargs): return - # pylint: disable=unused-variable with self.assertRaisesRegex(TypeError, 'Cannot find parameter'): - @np_utils.np_doc(None, np_fun=np_fun) def f1(a): return with self.assertRaisesRegex(TypeError, 'is of kind'): - @np_utils.np_doc(None, np_fun=np_fun) def f2(x, kwargs): return - with self.assertRaisesRegex(TypeError, - 'Parameter "y" should have a default value'): - + with self.assertRaisesRegex( + TypeError, 'Parameter "y" should have a default value'): @np_utils.np_doc(None, np_fun=np_fun) def f3(x, y): return + np_utils.set_is_sig_mismatch_an_error(old_flag) + + def testSigMismatchIsNotError(self): + """Tests that signature mismatch is not an error (when configured so).""" + old_flag = np_utils.is_sig_mismatch_an_error() + np_utils.set_is_sig_mismatch_an_error(False) + + def np_fun(x, y=1, **kwargs): + return + + # The following functions all have signature mismatches, but they shouldn't + # throw errors when is_sig_mismatch_an_error() is False. + + @np_utils.np_doc(None, np_fun=np_fun) + def f1(a): + return + + def f2(x, kwargs): + return + + @np_utils.np_doc(None, np_fun=np_fun) + def f3(x, y): + return + + np_utils.set_is_sig_mismatch_an_error(old_flag) + + # pylint: enable=unused-variable + if __name__ == '__main__': test.main() From e242717155f25dbbd030a8ee0fda5172781ee186 Mon Sep 17 00:00:00 2001 From: Pavithra Vijay Date: Wed, 8 Jul 2020 15:02:57 -0700 Subject: [PATCH 86/88] Update v1 only training/adagrad_test with proper reason. 
PiperOrigin-RevId: 320273086 Change-Id: Ibcd14926a7e8acb457347501a50d32c1973e1bc3 --- tensorflow/python/training/adagrad_test.py | 18 +++++++++--------- 1 file changed, 9 insertions(+), 9 deletions(-) diff --git a/tensorflow/python/training/adagrad_test.py b/tensorflow/python/training/adagrad_test.py index 25fbec5eeec..6a735fa76b2 100644 --- a/tensorflow/python/training/adagrad_test.py +++ b/tensorflow/python/training/adagrad_test.py @@ -96,7 +96,7 @@ class AdagradOptimizerTest(test.TestCase): def testBasicLocked(self): self.doTestBasic(use_locking=True) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testMinimizeSparseResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): @@ -117,7 +117,7 @@ class AdagradOptimizerTest(test.TestCase): self.evaluate(var0), atol=0.01) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testTensorLearningRate(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): @@ -144,7 +144,7 @@ class AdagradOptimizerTest(test.TestCase): np.array([2.715679168701172, 3.715679168701172]), self.evaluate(var1)) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testSparseBasic(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): @@ -176,7 +176,7 @@ class AdagradOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType( np.array([[3.0], [3.715679168701172]]), self.evaluate(var1)) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testSparseRepeatedIndices(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): @@ -207,7 +207,7 @@ class AdagradOptimizerTest(test.TestCase): self.assertAllClose(aggregated_update_var, self.evaluate(repeated_index_update_var)) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testSparseRepeatedIndicesResourceVariable(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): @@ -232,7 +232,7 @@ class AdagradOptimizerTest(test.TestCase): self.assertAllCloseAccordingToType( self.evaluate(var_repeated), self.evaluate(var_aggregated)) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testSparseStability(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): @@ -269,7 +269,7 @@ class AdagradOptimizerTest(test.TestCase): -0.01029443 ]]), self.evaluate(var0)) - @test_util.run_deprecated_v1 + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testSharing(self): for dtype in [dtypes.half, dtypes.float32, dtypes.float64]: with self.cached_session(): @@ -306,7 +306,7 @@ class AdagradOptimizerTest(test.TestCase): np.array([2.715679168701172, 3.715679168701172]), self.evaluate(var1)) - @test_util.run_v1_only("b/120545219") + @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.") def testDynamicShapeVariable_Ok(self): with self.cached_session(): v = variable_scope.get_variable("v", initializer=constant_op.constant(1.), @@ -315,7 +315,7 @@ class AdagradOptimizerTest(test.TestCase): # Creating optimizer should cause no exception. 
      adagrad.AdagradOptimizer(3.0, initial_accumulator_value=0.1)
 
-  @test_util.run_v1_only("b/120545219")
+  @test_util.run_v1_only("train.AdagradOptimizer is V1 only API.")
   def testDynamicShapeVariableWithCallableInit(self):
     var0 = variable_scope.get_variable("var0",
                                        initializer=constant_op.constant(1.),

From a22b28a22a97597fc46b233d910005255178dea5 Mon Sep 17 00:00:00 2001
From: "A. Unique TensorFlower"
Date: Wed, 8 Jul 2020 15:22:14 -0700
Subject: [PATCH 87/88] Fix typo, AddIntemediateTensorsToFusedOp to AddIntermediateTensorsToFusedOp.

PiperOrigin-RevId: 320276937
Change-Id: Ic6765880c25af50410c46f8f1ee164653b47a4ec
---
 tensorflow/lite/tools/optimize/quantization_wrapper.cc      | 2 +-
 .../lite/tools/optimize/quantization_wrapper_utils.cc       | 2 +-
 tensorflow/lite/tools/optimize/quantization_wrapper_utils.h | 2 +-
 .../lite/tools/optimize/quantization_wrapper_utils_test.cc  | 6 +++---
 4 files changed, 6 insertions(+), 6 deletions(-)

diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper.cc b/tensorflow/lite/tools/optimize/quantization_wrapper.cc
index 56416c894ea..f8694338cf4 100644
--- a/tensorflow/lite/tools/optimize/quantization_wrapper.cc
+++ b/tensorflow/lite/tools/optimize/quantization_wrapper.cc
@@ -27,7 +27,7 @@ bool CreateModelForCalibration(const std::string& input_path,
     return false;
   }
   flatbuffers::FlatBufferBuilder builder;
-  if (AddIntemediateTensorsToFusedOp(&builder, &model) != kTfLiteOk) {
+  if (AddIntermediateTensorsToFusedOp(&builder, &model) != kTfLiteOk) {
     return false;
   }
   return WriteFile(output_path, builder.GetBufferPointer(), builder.GetSize());
diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc b/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc
index b8ed852d14b..753cf99375a 100644
--- a/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc
+++ b/tensorflow/lite/tools/optimize/quantization_wrapper_utils.cc
@@ -66,7 +66,7 @@ TfLiteStatus LoadModel(const string& path, ModelT* model) {
   return kTfLiteOk;
 }
 
-TfLiteStatus AddIntemediateTensorsToFusedOp(
+TfLiteStatus AddIntermediateTensorsToFusedOp(
     flatbuffers::FlatBufferBuilder* builder, ModelT* model) {
   // Return early if the model already has intermediate tensors.
   if (IntermediateTensorExists(model)) {
diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper_utils.h b/tensorflow/lite/tools/optimize/quantization_wrapper_utils.h
index 744b0dfc0ac..e58b46fd2bd 100644
--- a/tensorflow/lite/tools/optimize/quantization_wrapper_utils.h
+++ b/tensorflow/lite/tools/optimize/quantization_wrapper_utils.h
@@ -28,7 +28,7 @@ TfLiteStatus LoadModel(const string& path, ModelT* model);
 // Going through the model and add intermediates tensors if the ops have any.
 // Returns early if the model has already intermediate tensors. This is to
 // support cases where a model is initialized multiple times.
-TfLiteStatus AddIntemediateTensorsToFusedOp(
+TfLiteStatus AddIntermediateTensorsToFusedOp(
     flatbuffers::FlatBufferBuilder* builder, ModelT* model);
 
 // Write model to a given location.
diff --git a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
index d3836e3463b..371de7d04bc 100644
--- a/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
+++ b/tensorflow/lite/tools/optimize/quantization_wrapper_utils_test.cc
@@ -59,7 +59,7 @@ TEST(LstmPreprocess, Add2Tensors) {
 
   // Add 2 tensors.
   flatbuffers::FlatBufferBuilder builder;
-  tflite::optimize::AddIntemediateTensorsToFusedOp(&builder, model.get());
+  tflite::optimize::AddIntermediateTensorsToFusedOp(&builder, model.get());
 
   // Verify results.
   EXPECT_EQ(model->operator_codes.size(), 1);
@@ -84,8 +84,8 @@ TEST(LstmPreprocess, Add2Tensors) {
   EXPECT_THAT(model->subgraphs[0]->operators[0]->intermediates,
               ElementsAreArray({21, 22, 23, 24, 25}));
 
-  // Call AddIntemediateTensorsToFusedOp again and expect no change in model.
-  tflite::optimize::AddIntemediateTensorsToFusedOp(&builder, model.get());
+  // Call AddIntermediateTensorsToFusedOp again and expect no change in model.
+  tflite::optimize::AddIntermediateTensorsToFusedOp(&builder, model.get());
 
   // Verify results.
   EXPECT_EQ(model->operator_codes.size(), 1);

From 486ac1e10a51fb06f45fec29fcbbd33ba7b5694e Mon Sep 17 00:00:00 2001
From: Gunhan Gulsoy
Date: Wed, 8 Jul 2020 15:22:39 -0700
Subject: [PATCH 88/88] Update TF's six dependency.

PiperOrigin-RevId: 320277020
Change-Id: I84b559c0aaff61e6745e141235d4afeed5e0cafb
---
 tensorflow/workspace.bzl | 8 ++++----
 1 file changed, 4 insertions(+), 4 deletions(-)

diff --git a/tensorflow/workspace.bzl b/tensorflow/workspace.bzl
index 57d5462a099..4c8d94fa1a9 100755
--- a/tensorflow/workspace.bzl
+++ b/tensorflow/workspace.bzl
@@ -434,12 +434,12 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
     tf_http_archive(
         name = "six_archive",
         build_file = clean_dep("//third_party:six.BUILD"),
-        sha256 = "d16a0141ec1a18405cd4ce8b4613101da75da0e9a7aec5bdd4fa804d0e0eba73",
-        strip_prefix = "six-1.12.0",
+        sha256 = "30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259",
+        strip_prefix = "six-1.15.0",
         system_build_file = clean_dep("//third_party/systemlibs:six.BUILD"),
         urls = [
-            "https://storage.googleapis.com/mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz",
-            "https://pypi.python.org/packages/source/s/six/six-1.12.0.tar.gz",
+            "https://storage.googleapis.com/mirror.tensorflow.org/pypi.python.org/packages/source/s/six/six-1.15.0.tar.gz",
+            "https://pypi.python.org/packages/source/s/six/six-1.15.0.tar.gz",
        ],
     )
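
A quick way to sanity-check the sha256 recorded for a bumped archive such as six above is to recompute it locally before sending the change. The sketch below is illustrative only and not part of the patch series: it assumes Python 3 with network access, the URL and expected digest are taken from the tf_http_archive entry in the diff above, and the helper name is made up.

# Recompute the checksum that workspace.bzl pins for six 1.15.0.
# Illustrative helper; not part of TensorFlow's build tooling.
import hashlib
import urllib.request

SIX_URL = "https://pypi.python.org/packages/source/s/six/six-1.15.0.tar.gz"

def sha256_of(url):
    """Downloads `url` and returns the hex SHA-256 digest of its contents."""
    with urllib.request.urlopen(url) as response:
        return hashlib.sha256(response.read()).hexdigest()

if __name__ == "__main__":
    # Expected to print the digest pinned above:
    # 30639c035cdb23534cd4aa2dd52c3bf48f06e5f4a941509c8bafd8ce11080259
    print(sha256_of(SIX_URL))

The mirror.tensorflow.org URL in the same entry is expected to serve an identical tarball, which is why a single sha256 covers both download locations.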