Merge branch 'master' into addsub_16x8

Elena Zhelezina 2020-06-24 19:12:39 +01:00 committed by GitHub
commit 1bcc3bc41c
110 changed files with 2320 additions and 1616 deletions

View File

@ -54,7 +54,7 @@ Status ProcessInputs(
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
int idx = inputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
fn_body->graph.IsValidOutputTensor(node, idx),
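
The guarded dereferences above in ProcessInputs and ProcessOutputs avoid following a null operation handle in the incoming operation/index pairs; the node pointer comes back as nullptr instead, presumably leaving the TF_RETURN_WITH_CONTEXT_IF_ERROR-wrapped validity check (e.g. IsValidOutputTensor) to report the failure. A minimal sketch of the same guard with hypothetical stand-in types, not the actual TF C API structs:

struct Node { /* graph node */ };
struct Operation { Node node; };
struct Output { Operation* oper; int index; };

// Resolve a possibly-null operation handle to its node, or return nullptr so
// that the later validity check can reject the entry instead of crashing here.
Node* ResolveNode(const Output& output) {
  return output.oper ? &output.oper->node : nullptr;
}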

View File

@ -19,6 +19,7 @@ tf_cc_shared_object(
cc_library(
name = "gcs_filesystem_impl",
srcs = ["gcs_filesystem.cc"],
hdrs = ["gcs_filesystem.h"],
copts = select({
"//conditions:default": [],
"//tensorflow:windows": get_win_copts(),
@ -55,14 +56,9 @@ tf_cc_test(
"notap",
],
deps = [
":gcs_helper",
"//tensorflow/c:env",
"//tensorflow/c:tf_status",
":gcs_filesystem_impl",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/filesystem:filesystem_interface",
"//tensorflow/core/platform:stacktrace_handler",
"//tensorflow/core/platform:test",
"@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
"@com_google_absl//absl/strings",
],
)

View File

@ -12,26 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include <stdlib.h>
#include <string.h>
#include <fstream>
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"
#ifdef TF_GCS_FILESYSTEM_TEST
// For testing purposes, we expose some functions.
#define TF_STATIC
#else
// Otherwise, we don't expose any symbol.
#define TF_STATIC static
#endif
// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
@ -48,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }
static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
char** bucket, char** object, TF_Status* status) {
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
char** object, TF_Status* status) {
size_t scheme_end = fname.find("://") + 2;
if (fname.substr(0, scheme_end + 1) != "gs://") {
TF_SetStatus(status, TF_INVALID_ARGUMENT,
@ -130,7 +120,7 @@ namespace tf_read_only_memory_region {
namespace tf_gcs_filesystem {
// TODO(vnvo2409): Add lazy-loading and customizing parameters.
TF_STATIC void Init(TF_Filesystem* filesystem, TF_Status* status) {
void Init(TF_Filesystem* filesystem, TF_Status* status) {
google::cloud::StatusOr<gcs::Client> client =
gcs::Client::CreateDefaultClient();
if (!client) {
@ -143,14 +133,14 @@ TF_STATIC void Init(TF_Filesystem* filesystem, TF_Status* status) {
TF_SetStatus(status, TF_OK, "");
}
static void Cleanup(TF_Filesystem* filesystem) {
void Cleanup(TF_Filesystem* filesystem) {
plugin_memory_free(filesystem->plugin_filesystem);
}
// TODO(vnvo2409): Implement later
static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
ParseGCSPath(path, false, &bucket, &object, status);
@ -166,8 +156,8 @@ static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_SetStatus(status, TF_OK, "");
}
static void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status) {
char* bucket;
char* object;
ParseGCSPath(path, false, &bucket, &object, status);
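
For orientation: the ParseGCSPath body is truncated above, but from its signature and the `gs://` scheme check it splits a path of the form gs://bucket/path/to/object into a bucket name and an object name, reporting TF_INVALID_ARGUMENT through `status` when the path is malformed. The following is a rough, hypothetical sketch of that behaviour; the helper name, error strings, and strdup-based allocation are assumptions, not code from this commit.

#include <string.h>

#include <string>

#include "absl/strings/match.h"
#include "absl/strings/string_view.h"
#include "tensorflow/c/tf_status.h"

// Hypothetical illustration of splitting "gs://bucket/path/to/object".
static void ParseGCSPathSketch(absl::string_view fname, bool object_empty_ok,
                               char** bucket, char** object,
                               TF_Status* status) {
  if (!absl::StartsWith(fname, "gs://")) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT, "GCS path must start with gs://");
    return;
  }
  absl::string_view rest = fname.substr(5);  // Drop the "gs://" prefix.
  size_t slash = rest.find('/');
  absl::string_view bucket_view =
      (slash == absl::string_view::npos) ? rest : rest.substr(0, slash);
  absl::string_view object_view =
      (slash == absl::string_view::npos) ? absl::string_view()
                                         : rest.substr(slash + 1);
  if (bucket_view.empty()) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "GCS path must contain a bucket name");
    return;
  }
  if (object_view.empty() && !object_empty_ok) {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
                 "GCS path must contain an object name");
    return;
  }
  // The real plugin allocates through its own helpers; strdup keeps this short.
  *bucket = strdup(std::string(bucket_view).c_str());
  *object = strdup(std::string(object_view).c_str());
  TF_SetStatus(status, TF_OK, "");
}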

View File

@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
char** object, TF_Status* status);
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
TF_WritableFile* file, TF_Status* status);
} // namespace tf_gcs_filesystem
#endif // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
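
The header above exposes the plugin entry points directly, replacing the TF_STATIC machinery removed from gcs_filesystem.cc, so tests and other translation units can link against them. Below is a minimal, hypothetical usage sketch; the standalone main, the link setup, and the availability of default GCS credentials are assumptions, and it simply mirrors the test fixture in the file that follows.

#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TF_Filesystem* filesystem = new TF_Filesystem;

  // Init creates a default GCS client and stores it in plugin_filesystem.
  tf_gcs_filesystem::Init(filesystem, status);
  if (TF_GetCode(status) == TF_OK) {
    // ... tf_gcs_filesystem::NewWritableFile / NewAppendableFile calls go here ...
    tf_gcs_filesystem::Cleanup(filesystem);
  }

  delete filesystem;
  TF_DeleteStatus(status);
  return 0;
}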

View File

@ -12,18 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"
#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
// Forward declaration
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
}
namespace tensorflow {
namespace {
@ -38,7 +34,7 @@ class GCSFilesystemTest : public ::testing::Test {
}
void TearDown() override {
TF_DeleteStatus(status_);
// TODO(vnvo2409): Add filesystem cleanup
tf_gcs_filesystem::Cleanup(filesystem_);
delete filesystem_;
}

View File

@ -138,6 +138,11 @@ bool IsI32Type(Type element_type) {
return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}
// Return true when the given element_type is I64.
bool IsI64Type(Type element_type) {
return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}
// Return true if the given Add operation has shapes supported by the CPU kernel.
bool VerifyAddOpShapeConstraints(AddOp op) {
auto element_type = getElementTypeOrSelf(op.output().getType());
@ -174,7 +179,8 @@ bool VerifySubOpShapeConstraints(SubOp op) {
// Allows F32, I32, I64, QUI8, and QI16 outputs when the operands have valid
// shapes, which are broadcastable shapes up to five dimensions or the same
// shapes.
if (element_type.isF32() || IsI32Type(element_type) ||
IsQUI8Type(element_type) || IsQI16Type(element_type)) {
IsI64Type(element_type) || IsQUI8Type(element_type) ||
IsQI16Type(element_type)) {
return VerifyOperandsHaveSameShapesOrBroadcastableShape(
/*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
/*max_bcast_rank=*/5);

View File

@ -2864,11 +2864,11 @@ def TFL_SubOp : TFL_Op<"sub", [
}];
let arguments = (
ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
TFL_AFAttr:$fused_activation_function);
let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);
let hasFolder = 1;

View File

@ -9,6 +9,15 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return
}
func @sub(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>) -> tensor<1xi64> {
%0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
return %0: tensor<1xi64>
// CHECK-LABEL: sub
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi64>
// CHECK: return
}
// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
// CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}

View File

@ -269,6 +269,14 @@ func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
return %0#0 : tensor<? x i32>
}
// CHECK-LABEL: testSubInt64
func @testSubInt64(tensor<? x i64>, tensor<? x i64>) -> tensor<? x i64> {
^bb0(%arg0: tensor<? x i64>, %arg1: tensor<? x i64>):
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
%0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i64>
return %0#0 : tensor<? x i64>
}
// CHECK-LABEL: testMul
func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):

View File

@ -346,6 +346,7 @@ func @replication(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<f32>) ->
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
// CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
// CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
// CHECK-NOT: _replicated_input_indices
// CHECK-SAME: n = 2 : i32
// CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( {
// CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
@ -382,6 +383,46 @@ func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<
// CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
// CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}}
// CHECK-DAG: [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _replicated_input_indices = [0, 1, 2, -1, -1, -1]
// Test TPUReplicatedInputs with non-contiguous `index` attributes.
// CHECK-LABEL: func @non_contigous_indices
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>)
func @non_contigous_indices(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 8 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
%1 = "tf.TPUReplicatedInput"(%arg1, %arg1) : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opB"(%1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
%2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opC"(%2) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _replicated_input_indices = [2, 8, -1]
// Test that the `is_mirrored_variable` attribute is preserved in the
// tf_device.replicate op.
// CHECK-LABEL: func @mirrored_variables
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_2:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_3:.*]]: tensor<!tf.resource<tensor<32xf32>>>)
func @mirrored_variables(%arg0: tensor<!tf.resource<tensor<32xf32>>>, %arg1: tensor<!tf.resource<tensor<32xf32>>>, %arg2: tensor<!tf.resource<tensor<32xf32>>>, %arg3: tensor<!tf.resource<tensor<32xf32>>>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
%1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
"tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device"} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _mirrored_variable_indices = [1]
// CHECK-SAME: _replicated_input_indices = [0, 1]
// -----
@ -407,8 +448,10 @@ func @bad_num_replicas() {
return
}
// -----
// Test that functions without TPUReplicateMetadata op are skipped without
// error
// CHECK-LABEL: func @missing_metadata_op
@ -483,22 +526,9 @@ func @leftover_replicated_output(%arg0: tensor<i1>) {
// -----
// Test bad TPUReplicatedInput positive `index` attribute.
func @bad_positive_index_input(%arg0: tensor<i1>) {
// expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got 1}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// -----
// Test bad TPUReplicatedInput negative `index` attribute.
func @bad_negative_index_input(%arg0: tensor<i1>) {
// expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got -2}}
// expected-error@+1 {{'tf.TPUReplicatedInput' op requires index to be at least -1, but got -2}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
@ -509,33 +539,12 @@ func @bad_negative_index_input(%arg0: tensor<i1>) {
// -----
// Test TPUReplicatedInput with conflicting `index` attribute. This will result
// in gaps in the TPUReplicatedInput ordering.
// Test TPUReplicatedInput with conflicting `index` attribute.
func @input_index_gaps(%arg0: tensor<i1>) {
// expected-error@+1 {{failed to sort 'tf.TPUReplicatedInput' ops, gap(s) found in indices}}
%0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
// expected-error@+1 {{'tf.TPUReplicatedInput' op requires indices to be unique, but found multiple 'tf.TPUReplicatedInput' ops with index 1}}
%1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
"tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>, tensor<i1>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// -----
// Test that the `is_mirrored_variable` attribute is preserved in the
// tf_device.replicate op.
// CHECK-LABEL: func @mirrored_variables
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_2:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_3:.*]]: tensor<!tf.resource<tensor<32xf32>>>)
func @mirrored_variables(%arg0: tensor<!tf.resource<tensor<32xf32>>>, %arg1: tensor<!tf.resource<tensor<32xf32>>>, %arg2: tensor<!tf.resource<tensor<32xf32>>>, %arg3: tensor<!tf.resource<tensor<32xf32>>>) {
%0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
%1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
"tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device"} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> ()
"tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
return
}
// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _mirrored_variable_indices = [1]

View File

@ -9,7 +9,7 @@
// padding_arg_index: 1
// CHECK-LABEL: func @single_arg_single_shape
func @single_arg_single_shape(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
"tf_device.cluster_func"(%ri_0, %ri_1) {func = @func0, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -36,7 +36,7 @@ func @func0(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// padding_arg_index: 2
// CHECK-LABEL: func @single_arg_multiple_shapes
func @single_arg_multiple_shapes(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>) {_replicated_input_indices = [0, 1, 2], n = 2 : i32} {
"tf_device.cluster_func"(%ri_0, %ri_1, %ri_2) {func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -68,7 +68,7 @@ func @func1(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// padding_arg_index: 3
// CHECK-LABEL: func @multiple_args
func @multiple_args(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>, [%arg0, %arg0] as %ri_4: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>, [%arg0, %arg0] as %ri_4: tensor<i1>) {_replicated_input_indices = [0, 1, 2, 3, 4], n = 2 : i32} {
"tf_device.cluster_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -89,7 +89,7 @@ func @func2(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tens
// padding_arg_index: 1
// CHECK-LABEL: func @remap_indices
func @remap_indices(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
"tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func3, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -124,7 +124,7 @@ func @func4(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// Test encapsulated function is not modified when there are no padding maps.
// CHECK-LABEL: func @no_padding_map
func @no_padding_map(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
"tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func5} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -140,7 +140,7 @@ func @func5(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// Test encapsulated function is not modified when padding maps is empty.
// CHECK-LABEL: func @empty_padding_map
func @empty_padding_map(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
"tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func6, padding_map = []} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -161,7 +161,7 @@ func @func6(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// padding_arg_index: 1
// CHECK-LABEL: func @unused_padding_map
func @unused_padding_map(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
"tf_device.cluster_func"(%ri_1) {func = @func7, padding_map = ["\10\02\18\01"]} : (tensor<i1>) -> ()
tf_device.return
}
@ -187,7 +187,7 @@ func @func7(%arg0: tensor<i1>) {
// shape_index: 2
// padding_arg_index: 3
func @missing_padding_arg(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>) {n = 2 : i32} {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>) {_replicated_input_indices = [0, 1, 2, 3], n = 2 : i32} {
// expected-warning@+1 {{bad 'padding_map' attribute at index 0, unused padding_arg_index 1}}
"tf_device.cluster_func"(%ri_0, %ri_2, %ri_3) {func = @func8, padding_map = ["\10\02\18\01", "\08\02\10\02\18\03"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
tf_device.return
@ -201,11 +201,55 @@ func @func8(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
return
}
// Test that a tf_device.replicate with a missing _replicated_input_indices
// attribute results in no transformation.
//
// Padding map "\10\02\18\01":
// arg_index: 0
// shape_index: 2
// padding_arg_index: 1
// CHECK-LABEL: func @missing_replicated_input_indices
func @missing_replicated_input_indices(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
"tf_device.cluster_func"(%ri_0, %ri_1) {func = @func9, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
return
}
// CHECK-LABEL: func @func9
// CHECK-NOT: xla_hlo.padding_map
func @func9(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
// Test padding map lifted to the associated encapsulated function when the
// replicated input indices are non-contiguous.
//
// Padding map "\08\08\10\06\18\02"
// arg_index: 8
// shape_index: 6
// padding_arg_index: 2
// CHECK-LABEL: func @non_contigous_indices
func @non_contigous_indices(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [2, 8], n = 2 : i32} {
"tf_device.cluster_func"(%ri_0, %ri_1) {func = @func10, padding_map = ["\08\08\10\06\18\02"]} : (tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
return
}
// CHECK-LABEL: func @func10
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}})
func @func10(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
// -----
// Test bad padding map attribute (not an array).
func @bad_padding_map() {
tf_device.replicate {n = 2 : i32} {
tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op requires 'padding_map' array attribute}}
"tf_device.cluster_func"() {func = @_func, padding_map = 0 : i32} : () -> ()
tf_device.return
@ -221,7 +265,7 @@ func @_func() {
// Test bad padding map attribute (element in array is not a string).
func @bad_padding_map_element() {
tf_device.replicate {n = 2 : i32} {
tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, not a string}}
"tf_device.cluster_func"() {func = @_func, padding_map = [0 : i32]} : () -> ()
tf_device.return
@ -237,7 +281,7 @@ func @_func() {
// Test unparsable padding map.
func @bad_padding_map_proto() {
tf_device.replicate {n = 2 : i32} {
tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}}
"tf_device.cluster_func"() {func = @_func, padding_map = ["z"]} : () -> ()
tf_device.return
@ -258,8 +302,8 @@ func @_func() {
// shape_index: 2
// padding_arg_index: 1
func @negative_arg_index(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}}
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be nonnegative, but got -1}}
"tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -272,27 +316,6 @@ func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// -----
// Test out of bound arg index.
//
// Padding map "\08\02\10\02\18\01":
// arg_index: 2
// shape_index: 2
// padding_arg_index: 1
func @bad_arg_index(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}}
"tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
return
}
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
// -----
// Test negative padding arg index.
//
// Padding map "\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01":
@ -300,8 +323,8 @@ func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// shape_index: 2
// padding_arg_index: -1
func @negative_padding_arg_index(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}}
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be nonnegative, but got -1}}
"tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
@ -311,24 +334,3 @@ func @negative_padding_arg_index(%arg0: tensor<i1>) {
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}
// -----
// Test out of bound padding arg index.
//
// Padding map "\08\01\10\02\18\02":
// arg_index: 1
// shape_index: 2
// padding_arg_index: 2
func @bad_padding_arg_index(%arg0: tensor<i1>) {
tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
// expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got 2}}
"tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor<i1>, tensor<i1>) -> ()
tf_device.return
}
return
}
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
return
}

View File

@ -2,460 +2,455 @@
// Tests that missing `_xla_outside_compilation` attribute value results in an error.
func @missing_outside_compilation_attribute() -> () {
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
// expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
"tf.B"() {_xla_outside_compilation = ""} : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// Tests that TPU cluster with no outside compilation does not generate parallel_execute.
// -----
// CHECK-LABEL: func @no_outside_compilation
func @no_outside_compilation() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
%2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// Tests that TPU cluster with no outside compilation does not generate parallel_execute.
// CHECK-NOT: "tf_device.parallel_execute"
// CHECK-LABEL: func @no_outside_compilation
func @no_outside_compilation() -> tensor<?xi32> {
%0 = "tf_device.cluster"() ( {
%1 = "tf.A"() : () -> tensor<?xi32>
%2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
return %0 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with no input or output dependencies.
// CHECK-NOT: "tf_device.parallel_execute"
// Tests extraction of a single outside compiled cluster with no input or output dependencies.
// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK: cluster_attr = "cluster_attr"
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependencies.
// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: "tf.D"
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.E"
// CHECK: cluster_attr = "cluster_attr"
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.E"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// Tests extraction of multiple outside compiled clusters with no input or output dependencies.
// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK-COUNT-2: "tf_device.launch"
// CHECK: "tf_device.cluster"
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
"tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
"tf.E"() : () -> ()
tf_device.return
}) {cluster_attr = "cluster_attr"} : () -> ()
return
}
// Tests extraction of a single outside compiled cluster with single TPU cluster return.
// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
// CHECK: tf_device.return
// CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NOT: _xla_outside_compilation
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
%3 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %3 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
"tf.C"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependencies.
// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.
// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xf32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
// CHECK: tf_device.return
// CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
%1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2, %3 = "tf_device.cluster"() ( {
%4 = "tf.A"() : () -> tensor<?xf32>
// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: "tf.C"
// CHECK-NEXT: "tf.D"
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf_device.cluster"
// CHECK-NEXT: "tf.A"
// CHECK-NEXT: "tf.E"
// CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
%5 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
"tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.E"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return
}
return %1 : tensor<?xf32>
}
// Tests extraction of multiple outside compiled clusters with no input or output dependencies.
// Tests extraction of a single outside compiled cluster with single device->host input.
// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
// CHECK: "tf_device.parallel_execute"
// CHECK-COUNT-2: "tf_device.launch"
// CHECK: "tf_device.cluster"
"tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
"tf.C"() : () -> ()
"tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
"tf.E"() : () -> ()
tf_device.return
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
return
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with single TPU cluster return.
// Tests extraction of a single outside compiled cluster with single host->device output.
// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.B"
// CHECK-NEXT: tf_device.return
// CHECK-NEXT: device = "TPU_REPLICATED_HOST"
// CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
// CHECK: tf_device.return
// CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
"tf.A"() : () -> ()
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
%3 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %3 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation
func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"()
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
return %1 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.
// Tests extraction of a single outside compiled cluster host output returned by TPU cluster.
// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xf32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
// CHECK: tf_device.return
// CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
%1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2, %3 = "tf_device.cluster"() ( {
%4 = "tf.A"() : () -> tensor<?xf32>
"tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
%5 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xf32>, tensor<?xi32>)
tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
}
// CHECK-LABEL: func @return_host_output_outside_compilation
func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: tf_device.return %[[HOST_OUTPUT]]
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
return %1 : tensor<?xf32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with single device->host input.
// Tests extraction of a single outside compiled cluster with single input/output.
// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation
func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
return %1 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with single host->device output.
// Tests extraction of a single outside compiled cluster with multiple input/output.
// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation
func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"()
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
// CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
%5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
%7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
%8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %8 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
return %1 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster host output returned by TPU cluster.
// Tests extraction of multiple outside compiled clusters with input/output.
// CHECK-LABEL: func @return_host_output_outside_compilation
func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: tf_device.return %[[HOST_OUTPUT]]
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
tf_device.return %4 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
// CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
// CHECK: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
// CHECK: "tf.E"(%[[HOST_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>)
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
%7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %7 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
return %1 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with single input/output.
// Tests extraction of a single outside compiled cluster with arg input and single device->host input.
// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation
func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
// CHECK-LABEL: func @mixed_input_single_outside_compilation
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
return %1 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with multiple input/output.
// Tests extraction of a multiple outside compiled clusters with single device->host input.
// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
// CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
%5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
%7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
%8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %8 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
// CHECK: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
"tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
tf_device.return %4 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a multiple outside compiled clusters with input/output.
// Tests extraction of a single outside compiled cluster with multiple device->host inputs.
// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]])
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
// CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
// CHECK: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]])
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
// CHECK: "tf.E"(%[[HOST_OUTPUT2]])
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>)
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
%7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
tf_device.return %7 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
"tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with arg input and single device->host input.
// Tests only directly used results of tpu cluster are remapped with
// parallel_execute.
// CHECK-LABEL: func @mixed_input_single_outside_compilation
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
tf_device.return %4 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// CHECK-LABEL: func @remapped_results
func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2:2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
}) {cluster_attr = "cluster_attr"} : () -> (tensor<?xi32>, tensor<?xi32>)
tf_device.return %2#1 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a multiple outside compiled clusters with single device->host input.
// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
// CHECK: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
// CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster2"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
%4 = "tf.C"() : () -> tensor<?xi32>
"tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
tf_device.return %4 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with multiple device->host inputs.
// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
// CHECK-NEXT: "tf_device.launch"
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
// CHECK: "tf_device.cluster"
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
// CHECK-SAME: key = "host_compute_channel_cluster1"
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"() : () -> (tensor<?xi32>)
"tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
"tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
%5 = "tf.E"() : () -> tensor<?xi32>
tf_device.return %5 : tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
tf_device.return %2 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}
// Tests only directly used results of tpu cluster are remapped with
// parallel_execute.
// CHECK-LABEL: func @remapped_results
func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
%2:2 = "tf_device.cluster"() ( {
%3 = "tf.A"() : () -> (tensor<?xi32>)
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
%5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xi32>, tensor<?xi32>)
tf_device.return %2#1 : tensor<?xi32>
}
return %1 : tensor<?xi32>
}

View File

@ -22,6 +22,7 @@ limitations under the License.
// not have ops outside of the cluster that are both operands and results of the
// cluster. Note, this currently does not handle side effecting ops yet.
#include <algorithm>
#include <iterator>
#include <memory>
#include <tuple>
@ -29,6 +30,7 @@ limitations under the License.
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/DenseMap.h"
#include "llvm/ADT/DenseSet.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallVector.h"
@ -59,6 +61,7 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
constexpr char kDeviceAttr[] = "device";
constexpr char kNameAttr[] = "name";
constexpr char kNumReplicasAttr[] = "num_replicas";
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
constexpr char kBadTPUReplicateAttrMsg[] =
@ -261,33 +264,42 @@ void MovePrecedingClusterUsers(tf_device::ClusterOp cluster,
// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
// of -1 are always after ops with a non negative `index`, and an arbitrary
// ordering is used as there are no dependencies on their relative ordering.
// ordering is used as there are no dependencies on their relative ordering. If
// there are multiple `tf.TPUReplicatedInput` ops with the same non negative
// index or if indices are less than -1, an error will be returned.
LogicalResult SortTPUReplicatedInputsByIndex(
llvm::ArrayRef<Operation*> inputs,
llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
const int input_size = inputs.size();
sorted_inputs->resize(input_size, nullptr);
int last_index = input_size - 1;
llvm::SmallDenseSet<int64_t, 8> unique_indices;
for (Operation* input : inputs) {
int64_t index =
llvm::cast<TF::TPUReplicatedInputOp>(input).index().getLimitedValue();
if (index >= input_size || index < -1)
return input->emitError() << "'" << input->getName().getStringRef()
<< "' index is not in range [-1, " << input_size
<< "), got " << index;
if (index == -1)
(*sorted_inputs)[last_index--] = input;
else
(*sorted_inputs)[index] = input;
llvm::cast<TF::TPUReplicatedInputOp>(input).index().getSExtValue();
if (index < -1)
return input->emitOpError()
<< "requires index to be at least -1, but got " << index;
if (index == -1) continue;
if (!unique_indices.insert(index).second)
return input->emitOpError()
<< "requires indices to be unique, but found multiple '"
<< input->getName() << "' ops with index " << index;
}
if (llvm::any_of(*sorted_inputs, [](Operation* op) { return op == nullptr; }))
return inputs.front()->emitError()
<< "failed to sort '" << inputs.front()->getName().getStringRef()
<< "' ops, gap(s) found in indices";
// Sort all TPUReplicatedInputs by `index` attribute to have
// TPUReplicatedInputs with indices be added to the `tf_device.replicate` op
// deterministically. If `index` attribute is -1, instead move them to the
// end.
sorted_inputs->assign(inputs.begin(), inputs.end());
std::stable_sort(
sorted_inputs->begin(), sorted_inputs->end(),
[](Operation* l, Operation* r) {
int64_t l_index =
llvm::cast<TF::TPUReplicatedInputOp>(l).index().getSExtValue();
int64_t r_index =
llvm::cast<TF::TPUReplicatedInputOp>(r).index().getSExtValue();
if (l_index == -1 && r_index != -1) return false;
if (r_index == -1 && l_index != -1) return true;
return l_index < r_index;
});
return success();
}
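The ordering contract described above (non-negative indices sorted ascending, all `-1` entries kept after them, ties preserving their original order) is easy to check in isolation. Below is a small standalone Go sketch of the same comparator over plain integers; it is purely illustrative and not part of the pass.

```go
package main

import (
	"fmt"
	"sort"
)

// sortReplicatedInputIndices mirrors the comparator above: non-negative
// indices sort ascending, while -1 entries keep their relative order at the end.
func sortReplicatedInputIndices(indices []int64) {
	sort.SliceStable(indices, func(i, j int) bool {
		l, r := indices[i], indices[j]
		if l == -1 && r != -1 {
			return false
		}
		if r == -1 && l != -1 {
			return true
		}
		return l < r
	})
}

func main() {
	idx := []int64{2, -1, 0, -1, 1}
	sortReplicatedInputIndices(idx)
	fmt.Println(idx) // [0 1 2 -1 -1]
}
```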
@ -315,6 +327,11 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
return failure();
// Index attribute value stored on TPUReplicatedInput op. These will be used
// later for dynamic padder.
llvm::SmallVector<int64_t, 8> replicated_input_indices;
bool has_replicated_input_index = false;
// Indices of the replicate op's arguments that are mirrored variables.
llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
@ -330,7 +347,14 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
replicated_inputs.push_back(
{input->getOperands(), input->getOperand(0).getType()});
if (llvm::cast<TF::TPUReplicatedInputOp>(input).is_mirrored_variable())
auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
int64_t tpu_replicated_input_index =
tpu_replicated_input.index().getSExtValue();
replicated_input_indices.push_back(tpu_replicated_input_index);
if (tpu_replicated_input_index != -1) has_replicated_input_index = true;
if (tpu_replicated_input.is_mirrored_variable())
mirrored_variable_indices.push_back(pos_and_input.index());
}
@ -340,6 +364,10 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
cluster.getLoc(), num_replicas,
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
replicated_inputs, cluster.getResultTypes());
if (has_replicated_input_index)
replicate_op.setAttr(kReplicatedInputIndicesAttr,
builder.getI64ArrayAttr(replicated_input_indices));
if (!mirrored_variable_indices.empty())
replicate_op.setAttr(kMirroredVariableIndicesAttr,
builder.getI64ArrayAttr(mirrored_variable_indices));

View File

@ -40,6 +40,7 @@ limitations under the License.
namespace mlir {
namespace TFTPU {
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
constexpr char kPaddingMapAttr[] = "padding_map";
// This pass remaps and assigns padding maps to an encapsulated function's
@ -56,14 +57,23 @@ struct TPUDynamicPaddingMapper
// Creates a mapping from replicated input index (in `tf_device.replicate` op)
// to `tf_device.cluster_func` operand index.
llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate) {
tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate,
ArrayAttr replicated_input_indices_attr) {
Block* replicate_block = &replicate.GetBody();
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands()))
if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>())
if (block_arg.getOwner() == replicate_block)
remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index();
for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>()) {
if (block_arg.getOwner() == replicate_block) {
int64_t replicated_input_index =
replicated_input_indices_attr[block_arg.getArgNumber()]
.cast<IntegerAttr>()
.getInt();
if (replicated_input_index != -1)
remapped_indices[replicated_input_index] = operand_and_idx.index();
}
}
}
return remapped_indices;
}
@ -73,16 +83,15 @@ llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
// indices. An error will be returned if an index is not found or parsing
// failed.
LogicalResult GetRemappedPaddings(
tf_device::ClusterFuncOp cluster_func, int num_replicated_args,
tf_device::ClusterFuncOp cluster_func,
const llvm::SmallDenseMap<int32_t, int32_t>& remapped_indices,
llvm::SmallVectorImpl<tensorflow::tpu::PaddingMap>* remapped_paddings) {
auto bad_index_msg = [num_replicated_args](int32_t index,
llvm::StringRef arg_type,
int32_t arg_index) {
auto bad_index_msg = [](int32_t index, llvm::StringRef arg_type,
int32_t arg_index) {
return llvm::formatv(
"bad '{0}' attribute at index {1}, {2} must be in [0, {3}), got "
"{4}",
kPaddingMapAttr, index, arg_type, num_replicated_args, arg_index)
"bad '{0}' attribute at index {1}, {2} must be nonnegative, but "
"got {3}",
kPaddingMapAttr, index, arg_type, arg_index)
.str();
};
@ -111,12 +120,12 @@ LogicalResult GetRemappedPaddings(
kPaddingMapAttr, idx, padding.getValue()));
const int32_t arg_index = padding_proto.arg_index();
if (arg_index >= num_replicated_args || arg_index < 0)
if (arg_index < 0)
return cluster_func.emitOpError()
<< bad_index_msg(idx, "arg_index", arg_index);
const int32_t padding_arg_index = padding_proto.padding_arg_index();
if (padding_arg_index >= num_replicated_args || padding_arg_index < 0)
if (padding_arg_index < 0)
return cluster_func.emitOpError()
<< bad_index_msg(idx, "padding_arg_index", padding_arg_index);
@ -175,17 +184,21 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
// LaunchFunc is not replicated, there will be no padding.
if (!replicate) return success();
const int num_replicated_args = replicate.GetBody().getNumArguments();
auto func = symbol_table->lookup<FuncOp>(cluster_func.func());
if (!func) return success();
auto replicated_input_indices_attr =
replicate.getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
if (!replicated_input_indices_attr) return success();
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
GetRemappedReplicatedInputIndices(cluster_func, replicate);
GetRemappedReplicatedInputIndices(cluster_func, replicate,
replicated_input_indices_attr);
llvm::SmallVector<tensorflow::tpu::PaddingMap, 4> remapped_paddings;
if (failed(GetRemappedPaddings(cluster_func, num_replicated_args,
remapped_indices, &remapped_paddings)))
if (failed(GetRemappedPaddings(cluster_func, remapped_indices,
&remapped_paddings)))
return failure();
AnnotateFunctionArgumentsWithPaddings(func, remapped_paddings);

View File

@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir {
namespace TFTPU {
@ -91,13 +93,14 @@ void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
// Creates a `tf_device::LaunchOp` to wrap cluster ops.
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
OpBuilder* builder, Operation* last_cluster_op) {
// TODO(b/154363171): Set the CPU device.
OpBuilder* builder, Operation* last_cluster_op,
llvm::StringRef host_device) {
// An empty string placeholder is used for the device as that will be later
// populated with the device of the associated TPUReplicateMetadata op.
llvm::SmallVector<Type, 8> result_types;
auto launch_op = builder->create<tf_device::LaunchOp>(
last_cluster_op->getLoc(), builder->getStringAttr(""), result_types);
last_cluster_op->getLoc(), builder->getStringAttr(host_device),
result_types);
launch_op.body().push_back(new Block);
@ -253,8 +256,9 @@ void MoveOutsideCompiledOps(
// Creates a `parallel_execute` op in place of launch with 'clusters` and
// 'launch` as regions.
void CreateParallelExecuteFromOutsideClusters(
tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) {
void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster,
const OutsideClusterMap& clusters,
llvm::StringRef host_device) {
OpBuilder builder(tpu_cluster);
// Create parallel_execute regions. The original TPU cluster computation
// is the extra region.
@ -269,8 +273,8 @@ void CreateParallelExecuteFromOutsideClusters(
Block& outside_block =
parallel_execute_op.GetRegionBlockWithIndex(cluster.index());
builder.setInsertionPointToEnd(&outside_block);
tf_device::LaunchOp host_launch_op =
CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back());
tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster(
&builder, cluster_ops.back(), host_device);
// Determine if there are any inputs that are provided out of cluster.
auto external_inputs = GetExternalOperands(cluster_ops);
@ -307,8 +311,14 @@ void CreateParallelExecuteFromOutsideClusters(
}
void TPUExtractOutsideCompilation::runOnOperation() {
// Get runtime devices information from the closest parent module.
auto module = getOperation();
mlir::TF::RuntimeDevices devices;
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
return signalPassFailure();
auto extract_result =
getOperation().walk([&](tf_device::ClusterOp tpu_cluster) {
module.walk([&](tf_device::ClusterOp tpu_cluster) {
OutsideClusterMap clusters;
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
&clusters)))
@ -316,7 +326,11 @@ void TPUExtractOutsideCompilation::runOnOperation() {
if (clusters.empty()) return WalkResult::advance();
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters);
std::string host_device;
tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
&host_device);
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters,
host_device);
return WalkResult::advance();
});

View File

@ -91,6 +91,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; }
const std::string& platform_name() const {
return pjrt_client_->platform_name();

View File

@ -34,13 +34,13 @@ END
description: <<END
The type of padding algorithm to use.
We specify the size-related attributes as:
The size-related attributes are specified as follows:
```python
ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
strides = [1, stride_planes, strides_rows, strides_cols, 1]
ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
strides = [1, stride_planes, strides_rows, strides_cols, 1]
```
END
}
summary: "Extract `patches` from `input` and put them in the \"depth\" output dimension. 3D extension of `extract_image_patches`."
summary: "Extract `patches` from `input` and put them in the `\"depth\"` output dimension. 3D extension of `extract_image_patches`."
}

View File

@ -9,10 +9,9 @@ def _lookup_file(filegroup, path):
return file
return None
def _gen_kernel_image_hdr_impl(ctx):
if not ctx.attr.gpu_archs:
fail("No GPU architecture specified, use --config=cuda or similar")
CubinInfo = provider(fields = ["cubins"])
def _gen_kernel_cubin_impl(ctx):
name = ctx.attr.name
tile_sizes = ctx.attr.tile_size.replace("x", ",")
cmd_args = []
@ -22,7 +21,6 @@ def _gen_kernel_image_hdr_impl(ctx):
cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
cubins = []
images = []
for arch in ctx.attr.gpu_archs:
# TODO(b/152737872): 'compute_' should generate both SASS and PTX.
arch = arch.replace("compute_", "sm_")
@ -41,13 +39,36 @@ def _gen_kernel_image_hdr_impl(ctx):
mnemonic = "compile",
)
cubins.append(cubin)
return [CubinInfo(cubins = cubins)]
_gen_kernel_cubin_rule = rule(
implementation = _gen_kernel_cubin_impl,
attrs = {
"mlir_op": attr.label(mandatory = True, allow_single_file = True),
"tile_size": attr.string(mandatory = True),
"same_shape": attr.string(),
"unroll_factors": attr.string(),
"gpu_archs": attr.string_list(mandatory = True),
"_tool": attr.label(
executable = True,
default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
cfg = "host",
),
},
output_to_genfiles = True,
)
def _gen_kernel_image_hdr_impl(ctx):
images = []
for cubin in ctx.attr.input[CubinInfo].cubins:
arch = cubin.path.split(".")[-2]
images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
# Generate fatbin file from all cubins.
fatbin = ctx.actions.declare_file("%s.fatbin" % name)
fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
ctx.actions.run(
outputs = [fatbin],
inputs = cubins,
inputs = ctx.attr.input[CubinInfo].cubins,
executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
arguments = [
"--64",
@ -73,37 +94,31 @@ _gen_kernel_image_hdr_rule = rule(
implementation = _gen_kernel_image_hdr_impl,
output_to_genfiles = True,
attrs = {
"mlir_op": attr.label(mandatory = True, allow_single_file = True),
"tile_size": attr.string(mandatory = True),
"same_shape": attr.string(),
"unroll_factors": attr.string(),
"input": attr.label(mandatory = True, providers = [CubinInfo]),
"out": attr.output(mandatory = True),
"symbol": attr.string(mandatory = True),
"gpu_archs": attr.string_list(mandatory = True),
"_cuda_root": attr.label(
default = Label("@local_config_cuda//cuda:cuda_root"),
),
"_tool": attr.label(
executable = True,
default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
cfg = "host",
),
},
)
def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None, unroll_factors = None):
def _gen_kernel_image_hdr(name, mlir_op, tile_size, same_shape = None, unroll_factors = None):
"""Generates a C header with fatbin data from a Tensorflow op."""
if cuda_gpu_architectures():
_gen_kernel_image_hdr_rule(
name = name,
_gen_kernel_cubin_rule(
name = name + "_cubin",
mlir_op = mlir_op,
tile_size = tile_size,
same_shape = same_shape,
unroll_factors = unroll_factors,
gpu_archs = cuda_gpu_architectures(),
)
_gen_kernel_image_hdr_rule(
name = name,
input = ":" + name + "_cubin",
out = "%s.h" % name,
symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
gpu_archs = cuda_gpu_architectures(),
tags = tags,
)
def _gen_mlir_op_impl(ctx):
@ -157,7 +172,6 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr
name = "{name}_{type}_kernel".format(name = name, type = type),
mlir_op = "{name}_{type}.mlir".format(name = name, type = type),
tile_size = tile_size,
tags = tags,
same_shape = same_shape,
unroll_factors = unroll_factors,
)

View File

@ -108,7 +108,7 @@ limitations under the License.
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 441 // Updated: 2020/6/23
#define TF_GRAPH_DEF_VERSION 442 // Updated: 2020/6/24
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
//

View File

@ -28,10 +28,12 @@ cc_library(
srcs = ["tpu_compile_op_common.cc"],
hdrs = ["tpu_compile_op_common.h"],
deps = [
":tpu_compile_op_options",
":tpu_compile_op_support",
":tpu_mesh_state_interface",
":tpu_program_group_interface",
":tpu_util",
":tpu_util_c_api_hdrs",
":tpu_util_hdrs",
"//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:shape_inference",
@ -50,6 +52,7 @@ cc_library(
"//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
alwayslink = 1,

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -28,8 +29,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"
@ -518,5 +521,41 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
return Status::OK();
}
void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";
std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));
CancellationToken token =
ctx->cancellation_manager()->get_cancellation_token();
const bool already_cancelled =
!ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
if (TpuCompile_ShouldTpuCompileOpIgnoreCancellation()) {
return;
}
// Sleep and exit in another thread so the cancellation manager can
// continue running callbacks.
ctx->env()->SchedClosure([ctx, done]() { ExitCountdown(ctx, done); });
});
// If the RPC was cancelled before we registered the cancellation callback,
// don't compile the TPU program.
OP_REQUIRES(ctx, !already_cancelled,
errors::Cancelled("RPC cancelled, not compiling TPU program"));
// We only want to abort the process if a cancellation actually occurs during
// compilation; we must deregister the callback in the success case. It
// doesn't hurt to also deregister the callback in the failure case; the
// CancellationManager ensures that already-registered callbacks will be run
// once cancellation has started.
auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
ctx->cancellation_manager()->DeregisterCallback(token);
done->store(true);
});
OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
}
} // namespace tpu
} // namespace tensorflow
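The Compute method above registers a cancellation callback, arranges for the process to exit if cancellation arrives while the uncancellable compile is running, and disarms the callback once compilation finishes. Here is a rough standalone Go sketch of the same register/disarm pattern; the names and the simulated compile step are invented for illustration and do not correspond to the actual kernel.

```go
package main

import (
	"context"
	"fmt"
	"time"
)

// compileWithWatchdog registers a cancellation hook before starting a long,
// uncancellable step and disarms it once the step finishes, analogous to the
// RegisterCallback/DeregisterCallback pair in the C++ code above.
func compileWithWatchdog(ctx context.Context) error {
	done := make(chan struct{})
	defer close(done) // disarm the watchdog on every exit path

	go func() {
		select {
		case <-ctx.Done():
			// The real kernel sleeps briefly and then exits the process,
			// because the underlying XLA compilation cannot be cancelled.
			fmt.Println("cancellation requested during compilation")
		case <-done:
			// Compilation finished first; nothing to do.
		}
	}()

	time.Sleep(50 * time.Millisecond) // stand-in for the uncancellable compile
	return nil
}

func main() {
	ctx, cancel := context.WithCancel(context.Background())
	defer cancel()
	fmt.Println(compileWithWatchdog(ctx)) // <nil>
}
```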

View File

@ -53,7 +53,8 @@ class TpuCompileOpKernelCommon {
virtual ~TpuCompileOpKernelCommon() = default;
virtual void Compute(OpKernelContext* ctx) = 0;
void Compute(OpKernelContext* ctx);
virtual Status ComputeInternal(OpKernelContext* ctx) = 0;
// Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not

View File

@ -95,6 +95,5 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
}
return Status::OK();
}
} // namespace tpu
} // namespace tensorflow

View File

@ -68,7 +68,6 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
// A callback called on exit.
void LogAndExit(int code);
} // namespace tpu
} // namespace tensorflow

View File

@ -31,6 +31,12 @@ void TpuCompile_ToTpuShapeRepresentation(
bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
SE_Status* status);
// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will
// exit when cancellation is requested for an XLA compile op. Some tests require
// this behavior to be disabled; the following flag function reports whether
// cancellation should be ignored.
bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();
} // extern "C"
struct TfTpu_UtilApiFn {

View File

@ -26,6 +26,7 @@ import (
"encoding/binary"
"fmt"
"io"
"math/bits"
"reflect"
"runtime"
"unsafe"
@ -80,7 +81,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
if dataType == String {
// TF_STRING tensors are encoded as an array of 8-byte offsets
// followed by string data. See c_api.h.
nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value)
nbytes = uintptr(nflattened*8 + int64(byteSizeOfEncodedStrings(val)))
}
var shapePtr *C.int64_t
if len(shape) > 0 {
@ -94,9 +95,22 @@ func NewTensor(value interface{}) (*Tensor, error) {
raw := tensorData(t.c)
buf := bytes.NewBuffer(raw[:0:len(raw)])
if dataType != String {
if err := encodeTensor(buf, val, shape); err != nil {
return nil, err
if isAllArray(val.Type()) {
// We have arrays all the way down, or just primitive types. We can
// just copy the memory in as it is all contiguous.
if err := copyPtr(buf, unpackEFace(value).data, int(val.Type().Size())); err != nil {
return nil, err
}
} else {
// When there are slices involved, the memory for each leaf slice may
// not be contiguous with the others or in the order we might
// expect, so we need to work our way down to each slice of
// primitives and copy them individually.
if err := encodeTensorWithSlices(buf, val, shape); err != nil {
return nil, err
}
}
if uintptr(buf.Len()) != nbytes {
return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
}
@ -112,6 +126,43 @@ func NewTensor(value interface{}) (*Tensor, error) {
return t, nil
}
// isAllArray returns true if the type is a primitive type, an array of
// primitive types, an array of arrays of primitive types, and so on. When this
// is true the data we want is contiguous in RAM.
func isAllArray(typ reflect.Type) bool {
switch typ.Kind() {
case reflect.Slice:
return false
case reflect.Array:
return isAllArray(typ.Elem())
default:
// We know the type is slices/arrays of slices/arrays of primitive types.
return true
}
}
// eface defines what an interface type actually is: a pointer to type
// information about the encapsulated type and a pointer to the encapsulated
// value.
type eface struct {
rtype unsafe.Pointer
data unsafe.Pointer
}
// unpackEFace gives us an efficient way to get a pointer to the value carried
// in an interface. If you wrap a pointer type in an interface then the pointer
// is directly stored in the interface struct. If you wrap a value type in an
// interface then the compiler copies the value into a newly allocated piece of
// memory and stores a pointer to that memory in the interface. So we're
// guaranteed to get a pointer. Go reflection doesn't expose the pointer to
// value types straightforwardly as it doesn't want you to think you have a
// reference to the original value. But we just want a pointer to make it
// efficient to read the value, so cheating like this should be safe and
// reasonable.
func unpackEFace(obj interface{}) *eface {
return (*eface)(unsafe.Pointer(&obj))
}
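A minimal standalone demonstration of the trick described above, assuming the same two-pointer interface layout: boxing a value type in an interface copies it, and the `data` pointer reaches that copy directly.

```go
package main

import (
	"fmt"
	"unsafe"
)

// eface and unpackEFace, repeated here so the sketch runs on its own.
type eface struct {
	rtype unsafe.Pointer
	data  unsafe.Pointer
}

func unpackEFace(obj interface{}) *eface {
	return (*eface)(unsafe.Pointer(&obj))
}

func main() {
	arr := [3]int32{7, 8, 9}
	// Boxing arr copies it; data points at that copy, so the whole array can
	// be read back without reflection allocating anything.
	p := unpackEFace(arr).data
	fmt.Println(*(*[3]int32)(p)) // [7 8 9]
}
```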
// ReadTensor constructs a Tensor with the provided type and shape from the
// serialized tensor contents in r.
//
@ -168,21 +219,152 @@ func (t *Tensor) Shape() []int64 { return t.shape }
// Tensor(int64, 0): int64
// Tensor(float64, 3): [][][]float64
func (t *Tensor) Value() interface{} {
typ := typeOf(t.DataType(), t.Shape())
val := reflect.New(typ)
raw := tensorData(t.c)
if t.DataType() != String {
if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil {
panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err))
shape := t.Shape()
dt := t.DataType()
return decodeTensor(raw, shape, dt).Interface()
}
func decodeTensor(raw []byte, shape []int64, dt DataType) reflect.Value {
// Create a 1-dimensional slice of the base type large enough for the data
// and copy the data in.
n := int(numElements(shape))
var (
slice reflect.Value
typ reflect.Type
)
if dt == String {
strs, err := decodeOneDimString(raw, n)
if err != nil {
panic(bug("unable to decode string with shape %v: %v", shape, err))
}
slice = reflect.ValueOf(strs)
typ = slice.Type()
} else {
nflattened := numElements(t.Shape())
d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()}
if err := d.decode(val, t.Shape()); err != nil {
panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err))
}
typ = typeForDataType(dt)
l := n * int(typ.Size())
typ = reflect.SliceOf(typ)
slice = reflect.MakeSlice(typ, n, n)
baseBytes := *(*[]byte)(unsafe.Pointer(&sliceHeader{
Data: unsafe.Pointer(slice.Pointer()),
Len: l,
Cap: l,
}))
copy(baseBytes, raw)
}
return reflect.Indirect(val).Interface()
// Now that we have the data in place in the base slice, we can add the
// dimensions. We want to walk backwards through the shape. If the shape is
// length 1 or 0 then we're already done.
if len(shape) == 0 {
return slice.Index(0)
}
if len(shape) == 1 {
return slice
}
// We have a special case if the tensor has no data. Our backing slice is
// empty, but we still want to create slices following the shape. In this
// case only the final part of the shape will be 0 and we want to recalculate
// n at this point ignoring that 0.
// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
// want 6 zero length slices to group as follows.
// {{} {}} {{} {}} {{} {}}
if n == 0 {
n = int(numElements(shape[:len(shape)-1]))
}
for i := len(shape) - 2; i >= 0; i-- {
underlyingSize := typ.Elem().Size()
typ = reflect.SliceOf(typ)
subsliceLen := int(shape[i+1])
if subsliceLen != 0 {
n = n / subsliceLen
}
// Just using reflection it is difficult to avoid unnecessary
// allocations while setting up the sub-slices as the Slice function on
// a slice Value allocates. So we end up doing pointer arithmetic!
// Pointer() on a slice gives us access to the data backing the slice.
// We insert slice headers directly into this data.
data := unsafe.Pointer(slice.Pointer())
nextSlice := reflect.MakeSlice(typ, n, n)
for j := 0; j < n; j++ {
// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
setSliceInSlice(nextSlice, j, sliceHeader{
Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
Len: subsliceLen,
Cap: subsliceLen,
})
}
slice = nextSlice
}
return slice
}
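Conceptually, the loop above does with slice headers and pointer arithmetic what the following plain Go sketch does with ordinary sub-slicing: carve one flat backing slice into nested slices that all alias the same memory. The shape and values here are made up for illustration.

```go
package main

import "fmt"

func main() {
	// Flat backing data for a hypothetical tensor of shape [2, 3].
	flat := []float32{1, 2, 3, 4, 5, 6}
	rows, cols := 2, 3

	nested := make([][]float32, rows)
	for i := 0; i < rows; i++ {
		// Each row aliases a window of flat; no element data is copied.
		nested[i] = flat[i*cols : (i+1)*cols]
	}
	fmt.Println(nested) // [[1 2 3] [4 5 6]]

	// Because the rows alias flat, a write through one is visible in the other.
	nested[1][0] = 42
	fmt.Println(flat[3]) // 42
}
```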
// setSliceInSlice sets slice[index] = content.
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
const sliceSize = unsafe.Sizeof(sliceHeader{})
// We must cast slice.Pointer to uintptr & back again to avoid GC issues.
// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
}
// decodeOneDimString decodes a string tensor into a one-dimensional []string.
func decodeOneDimString(raw []byte, nStrings int) ([]string, error) {
// Start by making an array of all the strings
strs := make([]string, nStrings)
// The first nStrings * 8 bytes of raw are offsets into the second half of
// the raw data. This second half is where the strings are encoded.
offsets := (*(*[]int64)(unsafe.Pointer(&raw)))[:nStrings]
// Reset raw after the offsets. Now the offsets will work relative to raw
raw = raw[nStrings*8:]
// Next we work out the final length of the string data so we can copy the
// good data out of raw (which is owned by the C tensor and won't be safe
// to access if the tensor is freed)
r := bytes.NewReader(raw)
var totalLength int
for _, offset := range offsets {
// At each offset we should find a varint length of a string.
// Errors here should mean the tensor is corrupt.
if _, err := r.Seek(offset, io.SeekStart); err != nil {
return nil, err
}
l, err := binary.ReadUvarint(r)
if err != nil {
return nil, err
}
totalLength += int(l)
}
// Let's allocate a big buffer to carry our string data.
stringData := make([]byte, 0, totalLength)
// Now copy the string data across into our new buffer, keeping track of the
// location of each string in the strs slice.
var cursor int
for i, offset := range offsets {
// At each offset we should find a varint length. Read it
if _, err := r.Seek(offset, io.SeekStart); err != nil {
return nil, err
}
l, err := binary.ReadUvarint(r)
if err != nil {
return nil, err
}
// Then copy the actual string into our large buffer
target := stringData[cursor : cursor+int(l)]
if _, err := r.Read(target); err != nil {
return nil, err
}
// Track where this string data is.
strs[i] = *(*string)(unsafe.Pointer(&target))
cursor += int(l)
}
// So now we have a big slice of strings
return strs, nil
}
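For reference, the layout the encode and decode paths above agree on can be reproduced in a few lines of plain Go: an 8-byte offset per string (relative to the end of the offset table), each pointing at a uvarint length followed by the string bytes. This sketch assumes a little-endian host for the offset table, whereas the real code writes offsets in the machine's native byte order.

```go
package main

import (
	"bytes"
	"encoding/binary"
	"fmt"
)

func main() {
	strs := []string{"ab", "cde"}

	// Encode: uvarint length + bytes per string; record each string's offset
	// relative to the start of the string-data region.
	var data bytes.Buffer
	offsets := make([]uint64, len(strs))
	for i, s := range strs {
		offsets[i] = uint64(data.Len())
		var lenBuf [binary.MaxVarintLen64]byte
		n := binary.PutUvarint(lenBuf[:], uint64(len(s)))
		data.Write(lenBuf[:n])
		data.WriteString(s)
	}
	var raw bytes.Buffer
	for _, off := range offsets {
		binary.Write(&raw, binary.LittleEndian, off) // 8-byte offset table entry
	}
	raw.Write(data.Bytes())

	// Decode, mirroring decodeOneDimString: skip the offset table, then read a
	// uvarint length and that many bytes at each offset.
	full := raw.Bytes()
	strData := full[8*len(strs):]
	for i := range strs {
		off := int(binary.LittleEndian.Uint64(full[8*i:]))
		l, n := binary.Uvarint(strData[off:])
		fmt.Printf("%q\n", string(strData[off+n:off+n+int(l)])) // "ab", then "cde"
	}
}
```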
// WriteContentsTo writes the serialized contents of t to w.
@ -261,18 +443,18 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
return shape, dt, fmt.Errorf("unsupported type %v", typ)
}
// typeOf converts from a DataType and Shape to the equivalent Go type.
func typeOf(dt DataType, shape []int64) reflect.Type {
var ret reflect.Type
func typeForDataType(dt DataType) reflect.Type {
for _, t := range types {
if dt == DataType(t.dataType) {
ret = t.typ
break
return t.typ
}
}
if ret == nil {
panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
}
panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
}
// typeOf converts from a DataType and Shape to the equivalent Go type.
func typeOf(dt DataType, shape []int64) reflect.Type {
ret := typeForDataType(dt)
for range shape {
ret = reflect.SliceOf(ret)
}
@ -289,109 +471,93 @@ func numElements(shape []int64) int64 {
// byteSizeOfEncodedStrings returns the size of the encoded strings in val.
// val MUST be a string, or a container (array/slice etc.) of strings.
func byteSizeOfEncodedStrings(val interface{}) uintptr {
if s, ok := val.(string); ok {
return uintptr(C.TF_StringEncodedSize(C.size_t(len(s))))
// TensorFlow encodes strings as the varint-encoded length followed by the
// string bytes. We could call into the C library to do this, but cgo has a
// heavy overhead, so we just do that calculation in Go.
func byteSizeOfEncodedStrings(val reflect.Value) int {
if val.Kind() == reflect.String {
return sizeVarUint(uint64(val.Len())) + val.Len()
}
if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
panic(fmt.Sprintf("unexpected type %s", val.Type()))
}
// Otherwise must be an array or slice.
var size uintptr
v := reflect.ValueOf(val)
for i := 0; i < v.Len(); i++ {
size += byteSizeOfEncodedStrings(v.Index(i).Interface())
var size int
for i := 0; i < val.Len(); i++ {
size += byteSizeOfEncodedStrings(val.Index(i))
}
return size
}
// encodeTensor writes v to the specified buffer using the format specified in
// c_api.h. Use stringEncoder for String tensors.
func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
switch v.Kind() {
case reflect.Bool:
b := byte(0)
if v.Bool() {
b = 1
}
if err := w.WriteByte(b); err != nil {
return err
}
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
}
case reflect.Array, reflect.Slice:
// If current dimension is a slice, verify that it has the expected size
// Go's type system makes that guarantee for arrays.
if v.Kind() == reflect.Slice {
expected := int(shape[0])
if v.Len() != expected {
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
}
}
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
if len(shape) == 1 && v.Len() > 0 {
switch v.Index(0).Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return binary.Write(w, nativeEndian, v.Interface())
}
}
subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
err := encodeTensor(w, v.Index(i), subShape)
if err != nil {
return err
}
}
default:
return fmt.Errorf("unsupported type %v", v.Type())
// sizeVarUint determines how many bytes it would take to encode the int v as
// an unsigned varint
func sizeVarUint(v uint64) int {
if v < 0x80 {
return 1
}
return nil
bits := bits.Len64(v)
return (bits + 6) / 7
}
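A quick self-contained check, illustrative only, that the formula above matches what encoding/binary actually emits:

```go
package main

import (
	"encoding/binary"
	"fmt"
	"math/bits"
)

// sizeVarUint, as above: the number of bytes PutUvarint would use for v.
func sizeVarUint(v uint64) int {
	if v < 0x80 {
		return 1
	}
	return (bits.Len64(v) + 6) / 7
}

func main() {
	var buf [binary.MaxVarintLen64]byte
	for _, v := range []uint64{0, 127, 128, 300, 1 << 20, 1<<64 - 1} {
		predicted := sizeVarUint(v)
		actual := binary.PutUvarint(buf[:], v)
		fmt.Println(v, predicted, actual) // predicted == actual for every v
	}
}
```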
// decodeTensor decodes the Tensor from the buffer to ptr using the format
// specified in c_api.h. Use stringDecoder for String tensors.
func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
switch typ.Kind() {
case reflect.Bool:
b, err := r.ReadByte()
// encodeTensorWithSlices writes v to the specified buffer using the format specified in
// c_api.h. Use stringEncoder for String tensors.
func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) error {
// If current dimension is a slice, verify that it has the expected size
// Go's type system makes that guarantee for arrays.
if v.Kind() == reflect.Slice {
expected := int(shape[0])
if v.Len() != expected {
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
}
} else if v.Kind() != reflect.Array {
return fmt.Errorf("unsupported type %v", v.Type())
}
// Once we have just a single dimension we can just copy the data
if len(shape) == 1 && v.Len() > 0 {
elt := v.Index(0)
if !elt.CanAddr() {
panic("cannot take address")
}
ptr := unsafe.Pointer(elt.Addr().Pointer())
return copyPtr(w, ptr, v.Len()*int(elt.Type().Size()))
}
subShape := shape[1:]
for i := 0; i < v.Len(); i++ {
err := encodeTensorWithSlices(w, v.Index(i), subShape)
if err != nil {
return err
}
ptr.Elem().SetBool(b == 1)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err
}
case reflect.Slice:
val := reflect.Indirect(ptr)
val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
if len(shape) == 1 && val.Len() > 0 {
switch val.Index(0).Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return binary.Read(r, nativeEndian, val.Interface())
}
}
for i := 0; i < val.Len(); i++ {
if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
return err
}
}
default:
return fmt.Errorf("unsupported type %v", typ)
}
return nil
}
// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
// this is not inspected by the garbage collector
type sliceHeader struct {
Data unsafe.Pointer
Len int
Cap int
}
// copyPtr copies the backing data for a slice or array directly into w. Note
// we don't need to worry about byte ordering because we want the natural byte
// order for the machine we're running on.
func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) error {
// Convert our slice header into a []byte so we can call w.Write
b := *(*[]byte)(unsafe.Pointer(&sliceHeader{
Data: ptr,
Len: l,
Cap: l,
}))
_, err := w.Write(b)
return err
}
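A standalone usage sketch of the copyPtr/sliceHeader approach above (the helpers are repeated so the example compiles on its own); it writes the raw backing bytes of an []int32 into a buffer in the machine's native byte order.

```go
package main

import (
	"bytes"
	"fmt"
	"unsafe"
)

// sliceHeader and copyPtr, repeated from above: view l bytes starting at ptr
// as a []byte and write them without going through reflection.
type sliceHeader struct {
	Data unsafe.Pointer
	Len  int
	Cap  int
}

func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) error {
	b := *(*[]byte)(unsafe.Pointer(&sliceHeader{Data: ptr, Len: l, Cap: l}))
	_, err := w.Write(b)
	return err
}

func main() {
	vals := []int32{1, 2, 3}
	var buf bytes.Buffer
	// Three int32 values occupy 12 bytes of contiguous backing memory.
	if err := copyPtr(&buf, unsafe.Pointer(&vals[0]), len(vals)*4); err != nil {
		panic(err)
	}
	fmt.Println(buf.Len()) // 12
}
```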
type stringEncoder struct {
offsets io.Writer
offsets *bytes.Buffer
data []byte
offset uint64
status *status
@ -399,19 +565,18 @@ type stringEncoder struct {
func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
if v.Kind() == reflect.String {
if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil {
if err := copyPtr(e.offsets, unsafe.Pointer(&e.offset), int(unsafe.Sizeof(e.offset))); err != nil {
return err
}
var (
s = v.Interface().(string)
src = C.CString(s)
srcLen = C.size_t(len(s))
dst = (*C.char)(unsafe.Pointer(&e.data[e.offset]))
dstLen = C.size_t(uint64(len(e.data)) - e.offset)
)
e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c))
C.free(unsafe.Pointer(src))
return e.status.Err()
// A string is encoded as the varint length followed by the string bytes.
// We do this in Go to avoid the considerable overhead of a cgo call into
// the TensorFlow library.
s := v.String()
n := binary.PutUvarint(e.data[e.offset:], uint64(len(s)))
e.offset += uint64(n)
n = copy(e.data[e.offset:], s)
e.offset += uint64(n)
return nil
}
if v.Kind() == reflect.Slice {
@ -430,45 +595,6 @@ func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
return nil
}
type stringDecoder struct {
offsets io.Reader
data []byte
status *status
}
func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error {
if len(shape) == 0 {
var offset uint64
if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil {
return err
}
var (
src = (*C.char)(unsafe.Pointer(&d.data[offset]))
srcLen = C.size_t(len(d.data)) - C.size_t(offset)
dst *C.char
dstLen C.size_t
)
if offset > uint64(len(d.data)) {
return fmt.Errorf("invalid offsets in String Tensor")
}
C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c)
if err := d.status.Err(); err != nil {
return err
}
s := ptr.Interface().(*string)
*s = C.GoStringN(dst, C.int(dstLen))
return nil
}
val := reflect.Indirect(ptr)
val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0])))
for i := 0; i < val.Len(); i++ {
if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil {
return err
}
}
return nil
}
func bug(format string, args ...interface{}) error {
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
}
@ -489,22 +615,3 @@ func isTensorSerializable(dataType DataType) error {
return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
}
}
// nativeEndian is the byte order for the local platform. Used to send back and
// forth Tensors with the C API. We test for endianness at runtime because
// some architectures can be booted into different endian modes.
var nativeEndian binary.ByteOrder
func init() {
buf := [2]byte{}
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
switch buf {
case [2]byte{0xCD, 0xAB}:
nativeEndian = binary.LittleEndian
case [2]byte{0xAB, 0xCD}:
nativeEndian = binary.BigEndian
default:
panic("Could not determine native endianness.")
}
}

View File

@ -18,6 +18,7 @@ package tensorflow
import (
"bytes"
"fmt"
"io"
"reflect"
"testing"
@ -276,6 +277,7 @@ func TestReadTensorReadAll(t *testing.T) {
}
func benchmarkNewTensor(b *testing.B, v interface{}) {
b.ReportAllocs()
for i := 0; i < b.N; i++ {
if t, err := NewTensor(v); err != nil || t == nil {
b.Fatalf("(%v, %v)", t, err)
@ -283,32 +285,68 @@ func benchmarkNewTensor(b *testing.B, v interface{}) {
}
}
func BenchmarkNewTensor(b *testing.B) {
var (
// Some sample sizes from the Inception image labeling model.
// Where input tensors correspond to a 224x224 RGB image
// flattened into a vector.
vector [224 * 224 * 3]int32
)
b.Run("[150528]", func(b *testing.B) { benchmarkNewTensor(b, vector) })
}
func benchmarkValueTensor(b *testing.B, v interface{}) {
t, err := NewTensor(v)
if err != nil {
b.Fatalf("(%v, %v)", t, err)
}
b.ReportAllocs()
b.ResetTimer()
func benchmarkDecodeTensor(b *testing.B, t *Tensor) {
for i := 0; i < b.N; i++ {
_ = t.Value()
}
}
func BenchmarkDecodeTensor(b *testing.B) {
var (
// Some sample sizes from the Inception image labeling model.
// Where input tensors correspond to a 224x224 RGB image
// flattened into a vector.
vector [224 * 224 * 3]int32
)
t, err := NewTensor(vector)
if err != nil {
b.Fatalf("(%v, %v)", t, err)
func BenchmarkTensor(b *testing.B) {
// Some sample sizes from the Inception image labeling model.
// Where input tensors correspond to a 224x224 RGB image
// flattened into a vector.
var vector [224 * 224 * 3]int32
var arrays [100][100][100]int32
l3 := make([][][]float32, 100)
l2 := make([][]float32, 100*100)
l1 := make([]float32, 100*100*100)
for i := range l2 {
l2[i] = l1[i*100 : (i+1)*100]
}
b.Run("[150528]", func(b *testing.B) { benchmarkDecodeTensor(b, t) })
for i := range l3 {
l3[i] = l2[i*100 : (i+1)*100]
}
s1 := make([]string, 100*100*100)
s2 := make([][]string, 100*100)
s3 := make([][][]string, 100)
for i := range s1 {
s1[i] = "cheesit"
}
for i := range s2 {
s2[i] = s1[i*100 : (i+1)*100]
}
for i := range s3 {
s3[i] = s2[i*100 : (i+1)*100]
}
tests := []interface{}{
vector,
arrays,
l1,
l2,
l3,
s1,
s2,
s3,
}
b.Run("New", func(b *testing.B) {
for _, test := range tests {
b.Run(fmt.Sprintf("%T", test), func(b *testing.B) { benchmarkNewTensor(b, test) })
}
})
b.Run("Value", func(b *testing.B) {
for _, test := range tests {
b.Run(fmt.Sprintf("%T", test), func(b *testing.B) { benchmarkValueTensor(b, test) })
}
})
}

View File

@ -440,76 +440,6 @@ typedef struct TfLiteTensor {
// `dims_signature` contains [1, -1, -1, 3]).
const TfLiteIntArray* dims_signature;
} TfLiteTensor;
#else
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// NOTE: This flag is opt-in only at compile time.
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
#endif // TF_LITE_STATIC_MEMORY
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not
@ -547,6 +477,112 @@ typedef struct TfLiteNode {
// WARNING: This is an experimental interface that is subject to change.
struct TfLiteDelegate* delegate;
} TfLiteNode;
#else
// NOTE: This flag is opt-in only at compile time.
//
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// This struct does not use:
// - allocation
// - buffer_handle
// - data_is_stale
// - delegate
// - dims_signature
// - name
// - sparsity
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
// Specific reduced TfLiteNode struct for TF Micro runtime. This struct contains
// only the minimum fields required to represent a node.
//
// This struct does not use:
// - delegate
// - intermediates
// - temporaries
typedef struct TfLiteNode {
// Inputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* inputs;
// Outputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* outputs;
// Opaque data provided by the node implementer through `Registration.init`.
void* user_data;
// Opaque data provided to the node if the node is a builtin. This is usually
// a structure defined in builtin_op_data.h
void* builtin_data;
// Custom initial data. This is the opaque data provided in the flatbuffer.
// WARNING: This is an experimental interface that is subject to change.
const void* custom_initial_data;
int custom_initial_data_size;
} TfLiteNode;
#endif // TF_LITE_STATIC_MEMORY
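Reviewer note (illustrative, not part of the patch): the reduced structs above are opt-in only at compile time, via the TF_LITE_STATIC_MEMORY macro that guards this block. From a client's point of view the opt-in looks roughly like the sketch below; the helper name is hypothetical, but the `name` field it touches really is one of the members the micro layout drops.

#include "tensorflow/lite/c/common.h"

// Reads a tensor's name when the full struct is compiled in, and degrades
// gracefully under the reduced micro layout (which has no `name` member).
inline const char* TensorNameOrDefault(const TfLiteTensor* t) {
#ifdef TF_LITE_STATIC_MEMORY
  (void)t;
  return "<unnamed>";
#else
  return (t != nullptr && t->name != nullptr) ? t->name : "<unnamed>";
#endif
}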
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// WARNING: This is an experimental interface that is subject to change.
//

View File

@ -263,6 +263,43 @@ TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator,
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseCeil(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
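Reviewer note (illustrative, not part of the patch): the point of these per-op parse functions is selective registration. Assuming the declarations added to tensorflow/lite/core/api/flatbuffer_conversions.h later in this commit, a micro resolver can keep a table of only the parsers it needs, so the linker can drop ParseOpData's monolithic switch and every unused parser. The typedef and table below are hypothetical:

#include "tensorflow/lite/core/api/flatbuffer_conversions.h"

namespace tflite {

// Signature shared by the per-op parse functions in this file.
using BuiltinParseFunction = TfLiteStatus (*)(const Operator*, BuiltinOperator,
                                              ErrorReporter*,
                                              BuiltinDataAllocator*, void**);

struct ParserEntry {
  BuiltinOperator op;
  BuiltinParseFunction parse;
};

// Hypothetical selective table for a model that only uses CEIL and CONV_2D.
constexpr ParserEntry kSelectedParsers[] = {
    {BuiltinOperator_CEIL, ParseCeil},
    {BuiltinOperator_CONV_2D, ParseConv2D},
};

}  // namespace tflite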
TfLiteStatus ParseConcatenation(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteConcatenationParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteConcatenationParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const ConcatenationOptions* schema_params =
op->builtin_options_as_ConcatenationOptions();
if (schema_params != nullptr) {
params->activation =
ConvertActivation(schema_params->fused_activation_function());
params->axis = schema_params->axis();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
@ -295,6 +332,14 @@ TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseCos(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
@ -339,6 +384,22 @@ TfLiteStatus ParseDequantize(const Operator*, BuiltinOperator, ErrorReporter*,
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseEqual(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseFloor(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
@ -385,6 +446,53 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseGreater(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseGreaterEqual(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLitePoolParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLitePoolParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const Pool2DOptions* schema_params = op->builtin_options_as_Pool2DOptions();
if (schema_params != nullptr) {
params->padding = ConvertPadding(schema_params->padding());
params->stride_width = schema_params->stride_w();
params->stride_height = schema_params->stride_h();
params->filter_width = schema_params->filter_width();
params->filter_height = schema_params->filter_height();
params->activation =
ConvertActivation(schema_params->fused_activation_function());
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
@ -532,6 +640,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
return ParseArgMin(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_AVERAGE_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_CEIL: {
return ParseCeil(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_CONCATENATION: {
return ParseConcatenation(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_CONV_2D: {
return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data);
}
@ -546,11 +667,32 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data);
}
case BuiltinOperator_FLOOR: {
return ParseFloor(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_FULLY_CONNECTED: {
return ParseFullyConnected(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_GREATER: {
return ParseGreater(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_GREATER_EQUAL: {
return ParseGreaterEqual(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_MAX_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_L2_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_QUANTIZE: {
return ParseQuantize(op, op_type, error_reporter, allocator,
builtin_data);
@ -592,23 +734,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_AVERAGE_POOL_2D:
case BuiltinOperator_MAX_POOL_2D:
case BuiltinOperator_L2_POOL_2D: {
auto params = safe_allocator.Allocate<TfLitePoolParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
params->padding = ConvertPadding(pool_params->padding());
params->stride_width = pool_params->stride_w();
params->stride_height = pool_params->stride_h();
params->filter_width = pool_params->filter_width();
params->filter_height = pool_params->filter_height();
params->activation =
ConvertActivation(pool_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
auto params = safe_allocator.Allocate<TfLiteSequenceRNNParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -666,18 +791,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_HASHTABLE_LOOKUP:
// no-op.
return kTfLiteOk;
case BuiltinOperator_CONCATENATION: {
auto params = safe_allocator.Allocate<TfLiteConcatenationParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* concatenation_params =
op->builtin_options_as_ConcatenationOptions()) {
params->activation = ConvertActivation(
concatenation_params->fused_activation_function());
params->axis = concatenation_params->axis();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_MUL: {
auto params = safe_allocator.Allocate<TfLiteMulParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -1102,10 +1215,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_EQUAL:
case BuiltinOperator_EXP:
case BuiltinOperator_EXPAND_DIMS:
case BuiltinOperator_CEIL:
case BuiltinOperator_FLOOR:
case BuiltinOperator_GREATER:
case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_HARD_SWISH:
case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL:

View File

@ -91,10 +91,23 @@ TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseCeil(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseConcatenation(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseCos(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
@ -105,11 +118,32 @@ TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseEqual(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseFloor(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseGreater(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseGreaterEqual(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,

View File

@ -49,6 +49,7 @@ cc_library(
":tensor_type",
":util",
"//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util",
@ -84,6 +85,7 @@ cc_library(
":gpu_object",
":opencl_wrapper",
":util",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
@ -330,6 +332,7 @@ cc_library(
cc_library(
name = "gpu_object",
srcs = ["gpu_object.cc"],
hdrs = ["gpu_object.h"],
deps = [
":opencl_wrapper",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/str_split.h"
#include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
@ -457,21 +458,15 @@ std::string Arguments::GetListOfArgs() {
for (auto& t : buffers_) {
const std::string type_name =
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
std::string memory_type;
switch (t.second.memory_type) {
case MemoryType::GLOBAL:
memory_type = "__global";
break;
case MemoryType::CONSTANT:
memory_type = "__constant";
break;
case MemoryType::LOCAL:
memory_type = "__local";
break;
std::string attributes;
for (const auto& attr : t.second.attributes) {
attributes += absl::StrCat(" __attribute__((", attr, "))");
}
AppendArgument(absl::StrCat(memory_type, " ", type_name,
t.second.element_size, "* ", t.first),
&result);
AppendArgument(
absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
ToCLDataType(t.second.data_type, t.second.element_size),
"* ", t.first, attributes),
&result);
}
for (auto& t : image_buffers_) {
AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type),

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include <string>
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite {
@ -51,6 +54,7 @@ GPUResources BufferDescriptor::GetGPUResources(AccessType access_type) const {
desc.access_type = access_type;
desc.element_size = element_size;
desc.memory_type = memory_type;
desc.attributes = attributes;
resources.buffers.push_back({"buffer", desc});
return resources;
}
@ -61,7 +65,7 @@ absl::Status BufferDescriptor::PerformSelector(
if (selector == "Read") {
return PerformReadSelector(args, result);
} else if (selector == "GetPtr") {
return PerformGetPtrSelector(args, result);
return PerformGetPtrSelector(args, template_args, result);
} else {
return absl::NotFoundError(absl::StrCat(
"BufferDescriptor don't have selector with name - ", selector));
@ -80,13 +84,34 @@ absl::Status BufferDescriptor::PerformReadSelector(
}
absl::Status BufferDescriptor::PerformGetPtrSelector(
const std::vector<std::string>& args, std::string* result) const {
if (!args.empty()) {
return absl::NotFoundError(
absl::StrCat("BufferDescriptor GetPtr require zero arguments, but ",
args.size(), " was passed"));
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const {
if (args.size() > 1) {
return absl::NotFoundError(absl::StrCat(
"BufferDescriptor GetPtr requires one or zero arguments, but ",
args.size(), " were passed"));
}
if (template_args.size() > 1) {
return absl::NotFoundError(
absl::StrCat("BufferDescriptor GetPtr requires one or zero template "
"arguments, but ",
template_args.size(), " were passed"));
}
std::string conversion;
if (template_args.size() == 1) {
const std::string type_name = ToCLDataType(element_type, element_size);
if (type_name != template_args[0]) {
conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
template_args[0], "*)&");
}
}
if (args.empty()) {
*result = absl::StrCat(conversion, "buffer");
} else if (conversion.empty()) {
*result = absl::StrCat("(buffer + ", args[0], ")");
} else {
*result = absl::StrCat(conversion, "buffer[", args[0], "]");
}
*result = "buffer";
return absl::OkStatus();
}
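For orientation (a sketch, not part of the patch), the extended selector produces the following strings for the common call shapes, assuming a FLOAT16 buffer with element_size 4, i.e. an underlying __global half4* pointer:

inline void GetPtrSelectorExample() {
  BufferDescriptor desc;
  desc.element_type = DataType::FLOAT16;
  desc.element_size = 4;
  desc.memory_type = MemoryType::GLOBAL;

  std::string out;
  desc.PerformGetPtrSelector({}, {}, &out).IgnoreError();
  // out == "buffer"
  desc.PerformGetPtrSelector({"off"}, {}, &out).IgnoreError();
  // out == "(buffer + off)"
  desc.PerformGetPtrSelector({"off"}, {"float"}, &out).IgnoreError();
  // out == "(__global float*)&buffer[off]"
}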

View File

@ -30,9 +30,10 @@ namespace gpu {
namespace cl {
struct BufferDescriptor : public GPUObjectDescriptor {
DataType element_type; // FLOAT32 or FLOAT16
DataType element_type;
int element_size;
MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args,
@ -42,8 +43,9 @@ struct BufferDescriptor : public GPUObjectDescriptor {
GPUResources GetGPUResources(AccessType access_type) const override;
absl::Status PerformReadSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformGetPtrSelector(const std::vector<std::string>& args,
std::string* result) const;
absl::Status PerformGetPtrSelector(
const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const;
};
// Buffer represent linear GPU data storage with arbitrary data format.

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
namespace tflite {
namespace gpu {
namespace cl {
std::string MemoryTypeToCLType(MemoryType type) {
switch (type) {
case MemoryType::GLOBAL:
return "__global";
case MemoryType::CONSTANT:
return "__constant";
case MemoryType::LOCAL:
return "__local";
}
return "";
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -56,11 +56,14 @@ struct GPUImageBufferDescriptor {
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
std::string MemoryTypeToCLType(MemoryType type);
struct GPUBufferDescriptor {
DataType data_type;
AccessType access_type;
int element_size;
MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
cl_mem memory;
};

View File

@ -28,21 +28,17 @@ namespace gpu {
namespace cl {
namespace {
std::string GenerateConvolutionTransposedCode(
const OperationDef& op_def, int src_depth, int dst_channels,
const int2& kernel_size, const CLDevice& device,
const std::vector<ElementwiseOperation*>& linked_operations) {
TensorCodeGenerator src_tensor(
"src_data",
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
op_def.src_tensors[0]);
TensorCodeGenerator dst_tensor(
"dst_data",
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
std::string GenerateConvolutionTransposedCode(const OperationDef& op_def,
int src_depth, int dst_channels,
const int2& kernel_size,
Arguments* args) {
args->AddObjectRef(
"src_tensor", AccessType::READ,
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
args->AddObjectRef(
"dst_tensor", AccessType::WRITE,
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
std::string c = GetCommonDefines(op_def.precision);
const std::string channel_x = dst_channels == 1 ? "" : ".x";
const std::vector<std::string> postfix = {channel_x, ".y", ".z", ".w"};
const std::vector<std::string> channel = {".x", ".y", ".z", ".w"};
@ -62,36 +58,33 @@ std::string GenerateConvolutionTransposedCode(
break;
}
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
c += " __constant FLT4* filters";
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
c += " FLT4 bias_value \n";
c += ") {\n";
c += "$0) {\n";
if (op_def.IsBatchSupported()) {
c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / dst_size.w;\n";
c += " int B = linear_id % dst_size.w;\n";
c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
c += " args.src_tensor.SetBatchRef(B);\n";
} else {
c += " int X = get_global_id(0);\n";
}
c += " int Y = get_global_id(1);\n";
c += " if (X >= src_size.x || Y >= src_size.y) return;\n";
c += " if (X >= args.src_tensor.Width() || Y >= args.src_tensor.Height()) "
"return;\n";
c += " " + accum_type + " r[" + std::to_string(kernel_size.y) + "][" +
std::to_string(kernel_size.x) + "];\n";
c += " {\n";
c += " FLT4 src = " + src_tensor.ReadWHSB("X", "Y", "0", batch_id) + ";\n";
c += " FLT4 src = args.src_tensor.Read(X, Y, 0);\n";
int index = 0;
for (int y = 0; y < kernel_size.y; ++y) {
for (int x = 0; x < kernel_size.x; ++x) {
std::string r_s =
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
for (int d = 0; d < dst_channels; ++d) {
c += r_s + postfix[d] + " = dot(src, filters[" + std::to_string(index) +
"]);\n";
c += r_s + postfix[d] + " = dot(src, args.weights.Read(" +
std::to_string(index) + "));\n";
index++;
}
}
@ -100,15 +93,15 @@ std::string GenerateConvolutionTransposedCode(
for (int i = 1; i < src_depth; ++i) {
c += " if (X > " + std::to_string(-i) +
") { // always true, to reduce registers usage\n";
c += " FLT4 src = " +
src_tensor.ReadWHSB("X", "Y", std::to_string(i), batch_id) + ";\n";
c +=
" FLT4 src = args.src_tensor.Read(X, Y, " + std::to_string(i) + ");\n";
for (int y = 0; y < kernel_size.y; ++y) {
for (int x = 0; x < kernel_size.x; ++x) {
std::string r_s =
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
for (int d = 0; d < dst_channels; ++d) {
c += r_s + postfix[d] + " += dot(src, filters[" +
std::to_string(index) + "]);\n";
c += r_s + postfix[d] + " += dot(src, args.weights.Read(" +
std::to_string(index) + "));\n";
index++;
}
}
@ -121,21 +114,16 @@ std::string GenerateConvolutionTransposedCode(
for (int x = 0; x < kernel_size.x; ++x) {
const std::string x_coord = "X + " + std::to_string(x);
const std::string y_coord = "Y + " + std::to_string(y);
c += " if (" + x_coord + " < dst_size.x && " + y_coord +
" < dst_size.y) {\n";
c += " FLT4 result = bias_value;\n";
c += " if (" + x_coord + " < args.dst_tensor.Width() && " + y_coord +
" < args.dst_tensor.Height()) {\n";
c += " FLT4 result = args.weights.Read(" + std::to_string(index) +
");\n";
for (int d = 0; d < dst_channels; ++d) {
c += " result" + channel[d] + " += r[" + std::to_string(y) + "][" +
std::to_string(x) + "]" + postfix[d] + ";\n";
}
const std::string x_3dcoord = op_def.IsBatchSupported()
? "(" + x_coord + ") * dst_size.w + B"
: x_coord;
const LinkingContext context{"result", x_3dcoord, y_coord, "0"};
c += PostProcess(linked_operations, context);
c += " " +
dst_tensor.WriteWHSB("result", x_coord, y_coord, "0", batch_id) +
"\n";
c += " args.dst_tensor.Write(result, " + x_coord + ", " + y_coord +
", 0);\n";
c += " }\n";
}
}
@ -150,19 +138,11 @@ ConvolutionTransposedThin::ConvolutionTransposedThin(
: GPUOperation(definition),
kernel_size_(attr.weights.shape.w, attr.weights.shape.h),
src_channels_(attr.weights.shape.i),
dst_channels_(attr.weights.shape.o) {
float4 bias_value(0.0f);
for (int i = 0; i < attr.weights.shape.o; ++i) {
bias_value[i] = attr.bias.data[i];
}
bias_value_ = FLT4(definition_.precision, bias_value);
}
dst_channels_(attr.weights.shape.o) {}
ConvolutionTransposedThin::ConvolutionTransposedThin(
ConvolutionTransposedThin&& operation)
: GPUOperation(std::move(operation)),
weights_buf_(std::move(operation.weights_buf_)),
bias_value_(std::move(operation.bias_value_)),
kernel_size_(operation.kernel_size_),
src_channels_(operation.src_channels_),
dst_channels_(operation.dst_channels_),
@ -172,8 +152,6 @@ ConvolutionTransposedThin::ConvolutionTransposedThin(
ConvolutionTransposedThin& ConvolutionTransposedThin::operator=(
ConvolutionTransposedThin&& operation) {
if (this != &operation) {
weights_buf_ = std::move(operation.weights_buf_);
bias_value_ = std::move(operation.bias_value_);
std::swap(kernel_size_, operation.kernel_size_);
std::swap(src_channels_, operation.src_channels_);
std::swap(dst_channels_, operation.dst_channels_);
@ -186,9 +164,15 @@ ConvolutionTransposedThin& ConvolutionTransposedThin::operator=(
absl::Status ConvolutionTransposedThin::Compile(
const CreationContext& creation_context) {
const auto code = GenerateConvolutionTransposedCode(
std::string code = GenerateConvolutionTransposedCode(
definition_, DivideRoundUp(src_channels_, 4), dst_channels_, kernel_size_,
*creation_context.device, linked_operations_);
&args_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
{{"dst_tensor", element_wise_code}},
&code));
std::vector<CompilerOptions> options;
if (definition_.precision == CalculationsPrecision::F16 &&
@ -202,15 +186,10 @@ absl::Status ConvolutionTransposedThin::Compile(
}
absl::Status ConvolutionTransposedThin::BindArguments() {
kernel_.ResetBindingCounter();
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr()));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(bias_value_));
return absl::OkStatus();
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
return args_.Bind(kernel_.kernel());
}
int3 ConvolutionTransposedThin::GetGridSize() const {
@ -248,7 +227,7 @@ absl::Status CreateConvolutionTransposedThin(
}
*result = ConvolutionTransposedThin(definition, attr);
RETURN_IF_ERROR(
result->UploadWeights(attr.weights, creation_context.context));
result->UploadData(attr.weights, attr.bias, creation_context.context));
return absl::OkStatus();
}

View File

@ -58,8 +58,9 @@ class ConvolutionTransposedThin : public GPUOperation {
ConvolutionTransposedThin(const OperationDef& definition,
const ConvolutionTransposedAttributes& attr);
template <DataType T>
absl::Status UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
CLContext* context);
absl::Status UploadData(const tflite::gpu::Tensor<OHWI, T>& weights,
const tflite::gpu::Tensor<Linear, T>& biases,
CLContext* context);
template <DataType S, typename T>
void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
@ -68,9 +69,6 @@ class ConvolutionTransposedThin : public GPUOperation {
absl::Status BindArguments();
int3 GetGridSize() const;
Buffer weights_buf_;
FLT4 bias_value_;
int2 kernel_size_;
int src_channels_;
int dst_channels_;
@ -80,25 +78,50 @@ class ConvolutionTransposedThin : public GPUOperation {
};
template <DataType T>
absl::Status ConvolutionTransposedThin::UploadWeights(
const tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
absl::Status ConvolutionTransposedThin::UploadData(
const tflite::gpu::Tensor<OHWI, T>& weights,
const tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
const int src_depth = DivideRoundUp(src_channels_, 4);
const int elements_count =
kernel_size_.x * kernel_size_.y * src_depth * 4 * dst_channels_;
const int flt4_count =
kernel_size_.x * kernel_size_.y * src_depth * dst_channels_;
const int float4_size =
definition_.precision == CalculationsPrecision::F32 ? 16 : 8;
if (definition_.GetDataType() == DataType::FLOAT32) {
std::vector<float4> gpu_data(elements_count);
const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
BufferDescriptor desc;
desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
desc.element_size = 4;
desc.memory_type = MemoryType::CONSTANT;
Buffer weights_buffer;
if (f32_weights) {
std::vector<float4> gpu_data(flt4_count);
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
return CreateReadOnlyBuffer(float4_size * elements_count, gpu_data.data(),
context, &weights_buf_);
float4 bias_value(0.0f);
for (int i = 0; i < weights.shape.o; ++i) {
bias_value[i] = biases.data[i];
}
gpu_data.push_back(bias_value);
RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(float4) * gpu_data.size(),
gpu_data.data(), context,
&weights_buffer));
} else {
std::vector<half4> gpu_data(elements_count);
std::vector<half4> gpu_data(flt4_count);
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
return CreateReadOnlyBuffer(float4_size * elements_count, gpu_data.data(),
context, &weights_buf_);
half4 bias_value(0.0f);
for (int i = 0; i < weights.shape.o; ++i) {
bias_value[i] = biases.data[i];
}
gpu_data.push_back(bias_value);
RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(half4) * gpu_data.size(),
gpu_data.data(), context,
&weights_buffer));
}
args_.AddObject("weights", AccessType::READ,
absl::make_unique<Buffer>(std::move(weights_buffer)),
absl::make_unique<BufferDescriptor>(desc));
return absl::OkStatus();
}
template <DataType S, typename T>

View File

@ -147,6 +147,13 @@ absl::Status TensorDescriptor::PerformSelector(
} else if (selector == "Slices") {
*result = "slices";
return absl::OkStatus();
} else if (selector == "SliceStride") {
if (IsBatchedWidth()) {
*result = "width_batched * height";
} else {
*result = "width * height";
}
return absl::OkStatus();
} else if (selector == "Channels") {
*result = "channels";
return absl::OkStatus();

View File

@ -55,9 +55,12 @@ inline void GetActivationMinMax(FusedActivationFunctionType ac,
}
}
inline float ActivationFunctionWithMinMax(float x, float output_activation_min,
float output_activation_max) {
return std::min(std::max(x, output_activation_min), output_activation_max);
template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
T output_activation_max) {
using std::max;
using std::min;
return min(max(x, output_activation_min), output_activation_max);
}
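A quick illustration of what templating this clamp buys (values invented; not part of the patch; the legacy non-template float overload noted below stays for existing callers): the same helper now serves both the float kernels and the new int64 Sub path.

#include <cstdint>

inline void ActivationClampExample() {
  // int64 path used by the 64-bit Sub kernels added in this commit:
  int64_t clamped_i64 = ActivationFunctionWithMinMax<int64_t>(7, -1, 1);  // == 1
  // float path behaves exactly as before:
  float clamped_f = ActivationFunctionWithMinMax(-3.5f, -1.0f, 1.0f);     // == -1.0f
  (void)clamped_i64;
  (void)clamped_f;
}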
// Legacy function, left for compatibility only.

View File

@ -2766,37 +2766,39 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
}
}
inline void SubWithActivation(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int32* input1_data,
const RuntimeShape& input2_shape,
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
ruy::profiler::ScopeLabel label("SubWithActivation/int32");
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.quantized_activation_min,
params.quantized_activation_max);
}
inline void SetActivationMinMax(const ArithmeticParams& params,
int32* activation_min, int32* activation_max) {
*activation_min = params.quantized_activation_min;
*activation_max = params.quantized_activation_max;
}
inline void SubWithActivation(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
ruy::profiler::ScopeLabel label("SubWithActivation/float");
inline void SetActivationMinMax(const ArithmeticParams& params,
float* activation_min, float* activation_max) {
*activation_min = params.float_activation_min;
*activation_max = params.float_activation_max;
}
inline void SetActivationMinMax(const ArithmeticParams& params,
int64_t* activation_min,
int64_t* activation_max) {
*activation_min = params.int64_activation_min;
*activation_max = params.int64_activation_max;
}
template <typename T>
inline void SubWithActivation(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
T activation_min, activation_max;
SetActivationMinMax(params, &activation_min, &activation_max);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min,
params.float_activation_max);
input1_data[i] - input2_data[i], activation_min, activation_max);
}
}

View File

@ -1495,6 +1495,7 @@ inline void GatherNd(const RuntimeShape& params_shape,
}
}
#ifndef TF_LITE_STATIC_MEMORY
template <typename IndicesT = int32>
inline void GatherNdString(const RuntimeShape& params_shape,
const TfLiteTensor* params_data,
@ -1517,6 +1518,7 @@ inline void GatherNdString(const RuntimeShape& params_shape,
}
buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
}
#endif
template <typename IndicesT, typename UpdatesT>
inline void ScatterNd(const RuntimeShape& indices_shape,

View File

@ -260,6 +260,45 @@ inline void BroadcastSubSlow(const ArithmeticParams& params,
NDOpsHelper<N>(output_desc, sub_func);
}
template <int N = 5>
void BroadcastSubSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int64_t* input1_data,
const RuntimeShape& input2_shape,
const int64_t* input2_data,
const RuntimeShape& output_shape, int64_t* output_data) {
ruy::profiler::ScopeLabel label("BroadcastSubSlow/int64");
TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N);
NdArrayDesc<N> desc1;
NdArrayDesc<N> desc2;
NdArrayDesc<N> output_desc;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
// trailing dimension changing most rapidly (channels has the smallest stride,
// typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
//
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
auto sub_func = [&](int indexes[N]) {
output_data[SubscriptToIndex(output_desc, indexes)] =
ActivationFunctionWithMinMax(
input1_data[SubscriptToIndex(desc1, indexes)] -
input2_data[SubscriptToIndex(desc2, indexes)],
params.int64_activation_min, params.int64_activation_max);
};
NDOpsHelper<N>(output_desc, sub_func);
}
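To make the dimension-ordering comment concrete, here is a worked call of the new int64 overload (a sketch, not part of the patch; shapes and values are invented, and it assumes <vector>, <limits>, and the int64 SetActivationParams overload added to types.h in this commit):

inline void BroadcastSubSlowInt64Example() {
  // Subtract a per-channel bias from a [1, 2, 2, 3] tensor: the trailing
  // dimension of input2 broadcasts across batch, height and width.
  RuntimeShape shape1({1, 2, 2, 3});
  RuntimeShape shape2({1, 1, 1, 3});
  std::vector<int64_t> in1(12, 10);
  std::vector<int64_t> in2 = {1, 2, 3};
  std::vector<int64_t> out(12);

  ArithmeticParams params;
  SetActivationParams(std::numeric_limits<int64_t>::lowest(),
                      std::numeric_limits<int64_t>::max(), &params);
  BroadcastSubSlow(params, shape1, in1.data(), shape2, in2.data(), shape1,
                   out.data());
  // out == {9, 8, 7, 9, 8, 7, 9, 8, 7, 9, 8, 7}
}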
template <typename T, int N = 5>
void BroadcastSubSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data,
@ -434,40 +473,42 @@ void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
}
}
inline void SubWithActivation(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int32* input1_data,
const RuntimeShape& input2_shape,
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
inline void SetActivationMinMax(const ArithmeticParams& params,
int32* activation_min, int32* activation_max) {
*activation_min = params.quantized_activation_min;
*activation_max = params.quantized_activation_max;
}
inline void SetActivationMinMax(const ArithmeticParams& params,
float* activation_min, float* activation_max) {
*activation_min = params.float_activation_min;
*activation_max = params.float_activation_max;
}
inline void SetActivationMinMax(const ArithmeticParams& params,
int64_t* activation_min,
int64_t* activation_max) {
*activation_min = params.int64_activation_min;
*activation_max = params.int64_activation_max;
}
template <typename T>
inline void SubWithActivation(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("SubWithActivation");
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
T activation_min, activation_max;
SetActivationMinMax(params, &activation_min, &activation_max);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.quantized_activation_min,
params.quantized_activation_max);
input1_data[i] - input2_data[i], activation_min, activation_max);
}
}
inline void SubWithActivation(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min,
params.float_activation_max);
}
}
} // namespace reference_ops
} // namespace tflite

View File

@ -765,12 +765,17 @@ struct ArithmeticParams {
int input1_shift;
int32 input2_multiplier;
int input2_shift;
// TODO(b/158622529): Union the following activation params.
// uint8, etc, activation params.
int32 quantized_activation_min;
int32 quantized_activation_max;
// float activation params.
float float_activation_min;
float float_activation_max;
// int64 activation params.
int64_t int64_activation_min;
int64_t int64_activation_max;
// Processed output dimensions.
// Let input "a" be the one that broadcasts in the faster-changing dimension.
@ -1114,6 +1119,12 @@ inline void SetActivationParams(int32 min, int32 max, P* params) {
params->quantized_activation_max = max;
}
template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
params->int64_activation_min = min;
params->int64_activation_max = max;
}
template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) {
*min = params.quantized_activation_min;
@ -1126,6 +1137,11 @@ inline void GetActivationParams(const P& params, float* min, float* max) {
*max = params.float_activation_max;
}
template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
*min = params.int64_activation_min;
*max = params.int64_activation_max;
}
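A small sketch (not part of the patch) of how the overloads resolve purely on argument type, including the new int64 pair:

inline void ActivationParamsOverloadExample() {
  ArithmeticParams p;
  SetActivationParams(0, 255, &p);                     // int32 -> quantized_activation_*
  SetActivationParams(-6.0f, 6.0f, &p);                // float -> float_activation_*
  SetActivationParams(int64_t{-10}, int64_t{10}, &p);  // int64 -> int64_activation_*

  int64_t lo = 0, hi = 0;
  GetActivationParams(p, &lo, &hi);  // lo == -10, hi == 10
}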
} // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

View File

@ -44,6 +44,7 @@ inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) {
return &context->tensors[node->outputs->data[index]];
}
#ifndef TF_LITE_STATIC_MEMORY
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->temporaries->data[index]];
@ -52,11 +53,12 @@ inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) {
return &context->tensors[node->intermediates->data[index]];
}
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
inline int NumIntermediates(const TfLiteNode* node) {
return node->intermediates->size;
}
#endif // TF_LITE_STATIC_MEMORY
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
inline int64_t NumElements(const TfLiteIntArray* dims) {
int64_t count = 1;

View File

@ -142,8 +142,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version */ 1,
/* max_version */ 2);
AddBuiltin(BuiltinOperator_SUB, Register_SUB(),
/* min_version */ 1,
/* max_version */ 5);
/* min_version = */ 1,
/* max_version = */ 5);
AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(),
/* min_version = */ 1,
/* max_version = */ 4);

View File

@ -340,6 +340,11 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
EvalSubImpl<kernel_type, float>(context, node, params, data, input1,
input2, requires_broadcast, output);
break;
case kTfLiteInt64:
EvalSubImpl<kernel_type, int64_t>(context, node, params, data, input1,
input2, requires_broadcast, output);
break;
default:
TF_LITE_KERNEL_LOG(context, "output type %s is not supported.",
TfLiteTypeGetName(output->type));
@ -434,7 +439,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
output->type == kTfLiteInt64) {
EvalSub<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
output->type == kTfLiteInt16) {

View File

@ -63,6 +63,13 @@ class IntegerSubOpModel : public BaseSubOpModel {
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
};
class Int64SubOpModel : public BaseSubOpModel {
public:
using BaseSubOpModel::BaseSubOpModel;
std::vector<int64_t> GetOutput() { return ExtractVector<int64_t>(output_); }
};
class QuantizedSubOpModel : public BaseSubOpModel {
public:
using BaseSubOpModel::BaseSubOpModel;
@ -213,6 +220,57 @@ TEST(IntegerSubOpModel, WithBroadcast) {
}
}
TEST(Int64SubOpModel, NoActivation) {
Int64SubOpModel m({TensorType_INT64, {1, 2, 2, 1}},
{TensorType_INT64, {1, 2, 2, 1}}, {TensorType_INT64, {}},
ActivationFunctionType_NONE);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8});
m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3}));
}
TEST(Int64SubOpModel, ActivationRELU_N1_TO_1) {
Int64SubOpModel m({TensorType_INT64, {1, 2, 2, 1}},
{TensorType_INT64, {1, 2, 2, 1}}, {TensorType_INT64, {}},
ActivationFunctionType_RELU_N1_TO_1);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8});
m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 0, 1, 1}));
}
TEST(Int64SubOpModel, VariousInputShapes) {
std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
Int64SubOpModel m({TensorType_INT64, test_shapes[i]},
{TensorType_INT64, test_shapes[i]},
{TensorType_INT64, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5, 11, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3, 0, 19}))
<< "With shape number " << i;
}
}
TEST(Int64SubOpModel, WithBroadcast) {
std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}, {1, 3, 1, 2, 1}};
for (int i = 0; i < test_shapes.size(); ++i) {
Int64SubOpModel m({TensorType_INT64, test_shapes[i]},
{TensorType_INT64, {}}, // always a scalar
{TensorType_INT64, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
m.PopulateTensor<int64_t>(m.input2(), {1});
m.Invoke();
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear({-21, 1, 6, 7, 10, 19})))
<< "With shape number " << i;
}
}
template <TensorType tensor_type, typename integer_dtype>
void QuantizedTestsNoActivation() {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);

View File

@ -53,19 +53,14 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
// There is 1 output at index 3 in the tensors array.
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no temporaries.
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(conv_params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -66,19 +66,14 @@ TfLiteStatus ValidateDepthwiseConvGoldens(TfLiteTensor* tensors,
// There is 1 output at index 3 in the tensors array.
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no intermediates.
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -33,20 +33,23 @@ limitations under the License.
namespace {
// Create an area of memory to use for input, output, and intermediate arrays.
constexpr int tensor_arena_size = 73 * 1024;
uint8_t tensor_arena[tensor_arena_size];
// Align arena to 16 bytes to avoid alignment warnings on certain platforms.
constexpr int tensor_arena_size = 21 * 1024;
alignas(16) uint8_t tensor_arena[tensor_arena_size];
// A random number generator seed to generate input values.
constexpr int kRandomSeed = 42;
// NOLINTNEXTLINE
MicroBenchmarkRunner<int16_t> runner(g_keyword_scrambled_model_data,
tensor_arena, tensor_arena_size,
kRandomSeed);
MicroBenchmarkRunner<int16_t>& GetBenchmarkRunner() {
// NOLINTNEXTLINE
static MicroBenchmarkRunner<int16_t> runner(
g_keyword_scrambled_model_data, tensor_arena, tensor_arena_size, 0);
return runner;
}
void KeywordRunTenIerations() {
// TODO(b/152644476): Add a way to run more than a single deterministic input.
for (int i = 0; i < 10; i++) {
runner.RunSingleIterationRandomInput();
GetBenchmarkRunner().RunSingleIterationRandomInput();
}
}
@ -54,7 +57,7 @@ void KeywordRunTenIerations() {
TF_LITE_MICRO_BENCHMARKS_BEGIN
TF_LITE_MICRO_BENCHMARK(runner.RunSingleIterationRandomInput());
TF_LITE_MICRO_BENCHMARK(GetBenchmarkRunner().RunSingleIterationRandomInput());
TF_LITE_MICRO_BENCHMARK(KeywordRunTenIerations());

View File

@ -27,10 +27,6 @@ limitations under the License.
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"
// Create an area of memory to use for input, output, and intermediate arrays.
constexpr int tensor_arena_size = 73 * 1024;
uint8_t tensor_arena[tensor_arena_size];
/*
* Person Detection benchmark. Evaluates runtime performance of the visual
* wakewords person detection model. This is the same model found in
@ -40,24 +36,28 @@ uint8_t tensor_arena[tensor_arena_size];
namespace {
// Create an area of memory to use for input, output, and intermediate arrays.
// Align arena to 16 bytes to avoid alignment warnings on certain platforms.
constexpr int tensor_arena_size = 95 * 1024;
uint8_t tensor_arena[tensor_arena_size];
alignas(16) uint8_t tensor_arena[tensor_arena_size];
// NOLINTNEXTLINE
MicroBenchmarkRunner<uint8_t> runner(g_person_detect_model_data, tensor_arena,
tensor_arena_size, 0);
MicroBenchmarkRunner<uint8_t>& GetBenchmarkRunner() {
// NOLINTNEXTLINE
static MicroBenchmarkRunner<uint8_t> runner(
g_person_detect_model_data, tensor_arena, tensor_arena_size, 0);
return runner;
}
void PersonDetectionTenIerationsWithPerson() {
// TODO(b/152644476): Add a way to run more than a single deterministic input.
for (int i = 0; i < 10; i++) {
runner.RunSingleIterationCustomInput(g_person_data);
GetBenchmarkRunner().RunSingleIterationCustomInput(g_person_data);
}
}
void PersonDetectionTenIerationsWithoutPerson() {
// TODO(b/152644476): Add a way to run more than a single deterministic input.
for (int i = 0; i < 10; i++) {
runner.RunSingleIterationCustomInput(g_no_person_data);
GetBenchmarkRunner().RunSingleIterationCustomInput(g_no_person_data);
}
}
@ -65,7 +65,8 @@ void PersonDetectionTenIerationsWithoutPerson() {
TF_LITE_MICRO_BENCHMARKS_BEGIN
TF_LITE_MICRO_BENCHMARK(runner.RunSingleIterationCustomInput(g_person_data));
TF_LITE_MICRO_BENCHMARK(
GetBenchmarkRunner().RunSingleIterationCustomInput(g_person_data));
TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithPerson());
TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithoutPerson());

View File

@ -60,12 +60,10 @@ void TestReluFloat(const int* input_dims_data, const float* input_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -116,12 +114,10 @@ void TestRelu6Float(const int* input_dims_data, const float* input_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -177,12 +173,10 @@ void TestReluUint8(const int* input_dims_data, const float* input_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -242,12 +236,10 @@ void TestRelu6Uint8(const int* input_dims_data, const float* input_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -306,12 +298,10 @@ void TestReluInt8(const int* input_dims_data, const float* input_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -372,12 +362,10 @@ void TestRelu6Int8(const int* input_dims_data, const float* input_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
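Here and in the kernel tests that follow, the assignments to node.temporaries and node.delegate are dropped, presumably because the reduced TF Micro structs later in this diff (see the TfLiteTensor hunk at the end) drop or ignore that state. A sketch of the node setup that remains, folded into a hypothetical helper that is not part of the TFLite Micro test API:

    #include "tensorflow/lite/c/common.h"

    // Hypothetical helper capturing the per-test node setup these tests still do.
    TfLiteNode MakeTestNode(TfLiteIntArray* inputs, TfLiteIntArray* outputs,
                            void* user_data, void* builtin_data) {
      TfLiteNode node = {};              // zero-init any fields we do not set
      node.inputs = inputs;              // length-prefixed tensor index list
      node.outputs = outputs;
      node.user_data = user_data;        // value returned by registration->init()
      node.builtin_data = builtin_data;  // op params struct, or nullptr
      node.custom_initial_data = nullptr;
      node.custom_initial_data_size = 0;
      return node;
    }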

View File

@ -89,18 +89,14 @@ void ValidateAddGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -46,17 +46,13 @@ void ValidateArgMinMaxGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}

View File

@ -46,17 +46,13 @@ void TestCeil(const int* input_dims_data, const float* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_dims_count; ++i) {

View File

@ -56,17 +56,12 @@ TfLiteNode PrepareCircularBufferInt8(const int* input_dims_data,
// There is one output - tensor 1.
const int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no intermediates.
const int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -104,17 +99,12 @@ TfLiteStatus InvokeCircularBufferInt8(const int* input_dims_data,
// There is one output - tensor 1.
const int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no intermediates.
const int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
node->inputs = inputs_array;
node->outputs = outputs_array;
node->temporaries = temporaries_array;
node->builtin_data = nullptr;
node->custom_initial_data = nullptr;
node->custom_initial_data_size = 0;
node->delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);

View File

@ -44,18 +44,14 @@ void TestComparison(tflite::BuiltinOperator op, TfLiteTensor* tensors,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
const int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -57,19 +57,18 @@ void TestConcatenateTwoInputs(std::initializer_list<int> input1_dims_data,
.activation = kTfLiteActNone // Only activation supported in this impl
};
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -116,19 +115,18 @@ void TestConcatenateQuantizedTwoInputs(
.activation = kTfLiteActNone // Only activation supported in this impl
};
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
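These tests also switch from IntArrayFromInitializer to IntArrayFromInts over plain int arrays. A minimal sketch of the length-prefixed layout that helper expects; the include path is an assumption (whichever header declares tflite::testing::IntArrayFromInts), and the final check simply restates the layout.

    #include "tensorflow/lite/c/common.h"
    #include "tensorflow/lite/micro/test_helpers.h"  // assumption: declares IntArrayFromInts

    bool SketchIntArrayLayout() {
      int inputs_array_data[] = {2, 0, 1};  // two entries: tensor indices 0 and 1
      TfLiteIntArray* inputs_array =
          tflite::testing::IntArrayFromInts(inputs_array_data);
      // Element 0 of the plain array became the size; the rest became the data.
      return inputs_array->size == 2 && inputs_array->data[0] == 0 &&
             inputs_array->data[1] == 1;
    }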

View File

@ -76,18 +76,14 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(conv_params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node));

View File

@ -73,18 +73,14 @@ TfLiteStatus ValidateDepthwiseConvGoldens(const T* expected_output_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node));
}

View File

@ -48,18 +48,14 @@ void ValidateDequantizeGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -54,23 +54,18 @@ void TestElementwiseFloat(tflite::BuiltinOperator op,
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
auto inputs_array_data = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInitializer(inputs_array_data);
auto outputs_array_data = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer(outputs_array_data);
auto temporaries_array_data = {0};
TfLiteIntArray* temporaries_array =
IntArrayFromInitializer(temporaries_array_data);
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -119,19 +114,19 @@ void TestElementwiseBool(tflite::BuiltinOperator op,
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 1});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -47,18 +47,13 @@ void TestFloor(const int* input_dims_data, const float* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int intermediates_array_data[] = {0};
TfLiteIntArray* temporaries_array =
IntArrayFromInts(intermediates_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_dims_count; ++i) {

View File

@ -72,18 +72,14 @@ TfLiteStatus TestFullyConnectedFloat(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node));
}
@ -151,18 +147,14 @@ TfLiteStatus TestFullyConnectedQuantized(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node));

View File

@ -112,18 +112,14 @@ void TestL2Normalization(const int* input_dims_data, const T* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

View File

@ -51,19 +51,18 @@ void TestLogicalOp(tflite::BuiltinOperator op,
const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -62,12 +62,10 @@ void TestLogisticFloat(std::initializer_list<int> input_dims_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -122,12 +120,10 @@ void TestLogisticInt8(std::initializer_list<int> input_dims_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}

View File

@ -52,19 +52,18 @@ void TestMaxMinFloat(tflite::BuiltinOperator op,
const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -108,19 +107,18 @@ void TestMaxMinQuantized(
const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -162,19 +160,18 @@ void TestMaxMinQuantizedInt32(
const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -76,7 +76,6 @@ void TestMulFloat(std::initializer_list<int> input1_dims_data,
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -148,7 +147,6 @@ void TestMulQuantized(std::initializer_list<int> input1_dims_data,
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -50,17 +50,14 @@ void TestNegFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

View File

@ -64,19 +64,19 @@ void TestPackTwoInputsFloat(std::initializer_list<int> input1_dims_data,
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -141,19 +141,19 @@ void TestPackThreeInputsFloat(std::initializer_list<int> input1_dims_data,
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({3, 0, 1, 2});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {3, 0, 1, 2};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -214,19 +214,18 @@ void TestPackTwoInputsQuantized(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -284,19 +283,18 @@ void TestPackTwoInputsQuantized32(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -39,17 +39,13 @@ TfLiteStatus ValidatePadGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
TF_LITE_ENSURE_EQ(&context, kTfLiteOk,
registration->prepare(&context, &node));
@ -76,17 +72,13 @@ TfLiteStatus ValidatePadV2Goldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
// Prepare should catch dimension mismatches.
TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -66,18 +66,14 @@ void TestAveragePoolingFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -138,18 +134,14 @@ void TestAveragePoolingQuantized(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -209,18 +201,14 @@ void TestMaxPoolFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -283,18 +271,14 @@ void TestMaxPoolQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}

View File

@ -58,16 +58,13 @@ void TestPreluFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -122,16 +119,13 @@ void TestPreluQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}

View File

@ -50,18 +50,14 @@ void ValidateQuantizeGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -74,12 +74,10 @@ TfLiteStatus ValidateReduceGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -29,56 +29,29 @@ namespace tflite {
namespace testing {
namespace {
// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
void TestReshapeImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
if (shape_tensor == nullptr) {
constexpr int inputs_size = 1;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *output_tensor,
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
node.inputs = IntArrayFromInitializer({1, 0});
node.outputs = IntArrayFromInitializer({1, 1});
} else {
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *shape_tensor;
tensors[2] = *output_tensor;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
node.inputs = IntArrayFromInitializer({2, 0, 1});
node.outputs = IntArrayFromInitializer({1, 2});
}
::tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_RESHAPE);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
void* user_data = nullptr;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
node->user_data = user_data;
node->builtin_data = nullptr;
node->custom_initial_data = nullptr;
node->custom_initial_data_size = 0;
TF_LITE_MICRO_EXPECT_EQ(registration->init, nullptr);
TF_LITE_MICRO_EXPECT_EQ(registration->free, nullptr);
if (registration->prepare) {
// Errors can occur in either the Prepare or the Eval stage.
auto status = registration->prepare(&context, &node);
auto status = registration->prepare(context, node);
if (status != kTfLiteOk && expect_failure) {
return;
} else {
@ -86,11 +59,10 @@ void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
}
}
if (expect_failure) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
registration->invoke(&context, &node));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, registration->invoke(context, node));
return;
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(context, node));
const int output_dims_count = ElementCount(*output_tensor->dims);
const T* output_data = GetTensorData<T>(output_tensor);
@ -105,6 +77,59 @@ void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
}
}
// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeWithShapeImpl(TfLiteTensor* input_tensor,
TfLiteTensor* shape_tensor,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *shape_tensor;
tensors[2] = *output_tensor;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
int inputs_data[] = {2, 0, 1};
node.inputs = IntArrayFromInts(inputs_data);
int outputs_data[] = {1, 2};
node.outputs = IntArrayFromInts(outputs_data);
TestReshapeImpl(&context, &node, output_tensor, expected_output,
expected_dims, expect_failure);
}
// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeWithoutShapeImpl(TfLiteTensor* input_tensor,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
constexpr int inputs_size = 1;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *output_tensor,
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
int inputs_data[] = {1, 0};
node.inputs = IntArrayFromInts(inputs_data);
int outputs_data[] = {1, 1};
node.outputs = IntArrayFromInts(outputs_data);
TestReshapeImpl(&context, &node, output_tensor, expected_output,
expected_dims, expect_failure);
}
template <typename T = float, TfLiteType tensor_input_type = kTfLiteFloat32>
void TestReshape(std::initializer_list<int> input_dims_data,
std::initializer_list<T> input_data,
@ -122,14 +147,14 @@ void TestReshape(std::initializer_list<int> input_dims_data,
TfLiteTensor output_tensor =
CreateTensor<T, tensor_input_type>(output_data, output_dims);
// Reshape param is passed as op's param.
TestReshapeImpl<T>(&input_tensor, nullptr, &output_tensor, expected_output,
expected_dims, expect_failure);
TestReshapeWithoutShapeImpl<T>(&input_tensor, &output_tensor, expected_output,
expected_dims, expect_failure);
// Reshape param is passed as a tensor.
TfLiteIntArray* shape_dims = IntArrayFromInitializer(shape_dims_data);
auto shape_tensor =
CreateTensor<int32_t, kTfLiteInt32>(shape_data, shape_dims);
TestReshapeImpl<T>(&input_tensor, &shape_tensor, &output_tensor,
expected_output, expected_dims, expect_failure);
TestReshapeWithShapeImpl<T>(&input_tensor, &shape_tensor, &output_tensor,
expected_output, expected_dims, expect_failure);
}
} // namespace
} // namespace testing
@ -192,19 +217,20 @@ TF_LITE_MICRO_TEST(InvalidShape) {
using tflite::testing::CreateFloatTensor;
using tflite::testing::IntArrayFromInitializer;
using tflite::testing::IntArrayFromInts;
TfLiteIntArray* input_dims = IntArrayFromInitializer({3, 1, 2, 2});
int input_dims_data[] = {3, 1, 2, 2};
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
auto input_data = {3.0f};
auto input_tensor = CreateFloatTensor(input_data, input_dims);
float output_data[4];
int output_dims_data[6] = {2, 2, 1, 2, 2, 1};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
auto output_tensor = CreateFloatTensor(output_data, output_dims);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor
nullptr, // shape_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
tflite::testing::TestReshapeWithoutShapeImpl<float>(
&input_tensor, // input_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
);
}
@ -255,29 +281,32 @@ TF_LITE_MICRO_TEST(LegacyScalarOutput) {
using tflite::testing::CreateFloatTensor;
using tflite::testing::IntArrayFromInitializer;
using tflite::testing::IntArrayFromInts;
TfLiteIntArray* input_dims = IntArrayFromInitializer({1, 1});
int input_dims_data[] = {1, 1};
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
auto input_data = {3.0f};
auto input_tensor = CreateFloatTensor(input_data, input_dims);
float output_data[1];
int output_dims_data[2] = {1, 0};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
auto output_tensor = CreateFloatTensor(output_data, output_dims);
TfLiteIntArray* shape_dims = tflite::testing::IntArrayFromInitializer({1, 0});
int shape_dims_data[] = {1, 0};
TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
auto shape_tensor =
tflite::testing::CreateTensor<int32_t, kTfLiteInt32>({0}, shape_dims);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor
&shape_tensor, // shape_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
tflite::testing::TestReshapeWithShapeImpl<float>(
&input_tensor, // input_tensor
&shape_tensor, // shape_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor
nullptr, // shape_tensor
&output_tensor, // output_tensor
{3}, // expected_output
{}, // expected_dims
false // expect failure
tflite::testing::TestReshapeWithoutShapeImpl<float>(
&input_tensor, // input_tensor
&output_tensor, // output_tensor
{3}, // expected_output
{}, // expected_dims
false // expect failure
);
}
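The reshape tests above replace one helper that took an optional shape tensor with two explicit entry points, TestReshapeWithShapeImpl and TestReshapeWithoutShapeImpl, both funneling into a shared context/node-based TestReshapeImpl. A self-contained analogue of that refactor with stand-in names (not TFLite code), showing how the nullptr branching moves out of the shared path:

    #include <cstdio>
    #include <vector>

    namespace {

    // Shared path; no nullptr-based branching remains here.
    void RunImpl(const std::vector<int>& input_tensor_indices) {
      std::printf("running reshape test with %zu input tensor(s)\n",
                  input_tensor_indices.size());
    }

    // Two explicit entry points, mirroring the WithShape/WithoutShape wrappers.
    void RunWithShape(int input_idx, int shape_idx) {
      RunImpl({input_idx, shape_idx});
    }
    void RunWithoutShape(int input_idx) { RunImpl({input_idx}); }

    }  // namespace

    int main() {
      RunWithShape(0, 1);   // reshape dims supplied by a second (shape) tensor
      RunWithoutShape(0);   // reshape dims supplied only via the op's params
      return 0;
    }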

View File

@ -78,18 +78,14 @@ void TestResizeNearestNeighbor(const int* input_dims_data, const T* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

View File

@ -46,17 +46,13 @@ void TestRound(const int* input_dims_data, const float* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_dims_count; ++i) {

View File

@ -59,18 +59,14 @@ void TestSoftmaxFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -124,18 +120,14 @@ void TestSoftmaxQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -188,18 +180,14 @@ void TestSoftmaxQuantizedSigned(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -76,19 +76,19 @@ void TestSplitTwoOutputsFloat(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -179,19 +179,19 @@ void TestSplitFourOutputsFloat(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({4, 2, 3, 4, 5});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {4, 2, 3, 4, 5};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -275,19 +275,19 @@ void TestSplitTwoOutputsQuantized(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -362,19 +362,19 @@ void TestSplitTwoOutputsQuantized32(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -100,12 +100,10 @@ void TestStrideSlide(std::initializer_list<int> input_shape,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
if (expect_prepare_err) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,

View File

@ -89,18 +89,14 @@ void ValidateSubGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -201,7 +201,6 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
node.builtin_data = reinterpret_cast<void*>(&params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, prepare_status);
@ -275,7 +274,6 @@ void ValidateIntegerSVDFGoldens(const int batch_size, const int num_units,
node.builtin_data = reinterpret_cast<void*>(&params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -62,12 +62,10 @@ void TestTanhFloat(std::initializer_list<int> input_dims_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -122,12 +120,10 @@ void TestTanhInt8(std::initializer_list<int> input_dims_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}

View File

@ -79,19 +79,18 @@ void TestUnpackThreeOutputsFloat(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -156,19 +155,18 @@ void TestUnpackOneOutputFloat(std::initializer_list<int> input_dims_data,
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 1});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -245,19 +243,18 @@ void TestUnpackThreeOutputsQuantized(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -338,19 +335,18 @@ void TestUnpackThreeOutputsQuantized32(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -45,8 +45,8 @@ constexpr int kKeywordModelNodeAndRegistrationCount = 15;
// Run this test with '--copt=-DTF_LITE_MICRO_OPTIMIZED_RUNTIME' to get
// optimized memory runtime values:
#ifdef TF_LITE_STATIC_MEMORY
constexpr int kKeywordModelTotalSize = 18448;
constexpr int kKeywordModelTailSize = 17776;
constexpr int kKeywordModelTotalSize = 18080;
constexpr int kKeywordModelTailSize = 17408;
#else
constexpr int kKeywordModelTotalSize = 21040;
constexpr int kKeywordModelTailSize = 20368;
@ -65,8 +65,8 @@ constexpr int kTestConvModelNodeAndRegistrationCount = 7;
// NOTE: These values are measured on x86-64:
// TODO(b/158651472): Consider auditing these values on non-64 bit systems.
#ifdef TF_LITE_STATIC_MEMORY
constexpr int kTestConvModelTotalSize = 10960;
constexpr int kTestConvModelTailSize = 3216;
constexpr int kTestConvModelTotalSize = 10784;
constexpr int kTestConvModelTailSize = 3040;
#else
constexpr int kTestConvModelTotalSize = 11680;
constexpr int kTestConvModelTailSize = 3936;
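Quick arithmetic on the updated TF_LITE_STATIC_MEMORY expectations in this hunk: each model's total and tail shrink by the same number of bytes, so the entire reduction lands in the arena tail. The static_asserts below only restate that arithmetic.

    static_assert(18448 - 18080 == 368, "keyword model total shrinks by 368 bytes");
    static_assert(17776 - 17408 == 368, "keyword model tail shrinks by 368 bytes");
    static_assert(10960 - 10784 == 176, "conv model total shrinks by 176 bytes");
    static_assert(3216 - 3040 == 176, "conv model tail shrinks by 176 bytes");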

View File

@ -128,26 +128,20 @@ class MicroMutableOpResolver : public MicroOpResolver {
}
TfLiteStatus AddAveragePool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
*tflite::ops::micro::Register_AVERAGE_POOL_2D(),
ParseOpData);
ParsePool);
}
TfLiteStatus AddCeil() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CEIL,
*tflite::ops::micro::Register_CEIL(), ParseOpData);
*tflite::ops::micro::Register_CEIL(), ParseCeil);
}
TfLiteStatus AddConcatenation() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CONCATENATION,
*tflite::ops::micro::Register_CONCATENATION(),
ParseOpData);
ParseConcatenation);
}
TfLiteStatus AddConv2D() {
@ -156,10 +150,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
}
TfLiteStatus AddCos() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_COS, *tflite::ops::micro::Register_COS(),
ParseOpData);
ParseCos);
}
TfLiteStatus AddDepthwiseConv2D() {
@ -175,17 +167,13 @@ class MicroMutableOpResolver : public MicroOpResolver {
}
TfLiteStatus AddEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_EQUAL,
*tflite::ops::micro::Register_EQUAL(), ParseOpData);
*tflite::ops::micro::Register_EQUAL(), ParseEqual);
}
TfLiteStatus AddFloor() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_FLOOR,
*tflite::ops::micro::Register_FLOOR(), ParseOpData);
*tflite::ops::micro::Register_FLOOR(), ParseFloor);
}
TfLiteStatus AddFullyConnected() {
@ -195,18 +183,14 @@ class MicroMutableOpResolver : public MicroOpResolver {
}
TfLiteStatus AddGreater() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER,
*tflite::ops::micro::Register_GREATER(), ParseOpData);
*tflite::ops::micro::Register_GREATER(), ParseGreater);
}
TfLiteStatus AddGreaterEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER_EQUAL,
*tflite::ops::micro::Register_GREATER_EQUAL(),
ParseOpData);
ParseGreaterEqual);
}
TfLiteStatus AddL2Normalization() {
@ -274,10 +258,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
}
TfLiteStatus AddMaxPool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
*tflite::ops::micro::Register_MAX_POOL_2D(), ParseOpData);
*tflite::ops::micro::Register_MAX_POOL_2D(), ParsePool);
}
TfLiteStatus AddMean() {
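The hunks above swap the generic ParseOpData callback for operator-specific parse functions (ParsePool, ParseCeil, ParseConcatenation, and so on), so an application that registers only a few operators no longer pulls in the parsing logic for every builtin. A sketch of how those Add* helpers might be used; the include path and the non-templated resolver type are assumptions based on the TFLM layout of this era:

    // Hedged sketch assuming the MicroMutableOpResolver API shown above; newer
    // TFLM releases template the resolver on a fixed operator count, so the
    // exact type and header location may differ.
    #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"

    TfLiteStatus RegisterSelectedOps(tflite::MicroMutableOpResolver* resolver) {
      // Only the operators an application registers, and now only their
      // operator-specific parse functions, need to be linked into the binary.
      if (resolver->AddConv2D() != kTfLiteOk) return kTfLiteError;
      if (resolver->AddAveragePool2D() != kTfLiteOk) return kTfLiteError;
      if (resolver->AddMaxPool2D() != kTfLiteOk) return kTfLiteError;
      return kTfLiteOk;
    }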


@ -100,7 +100,6 @@ const std::map<string, string>& GetKnownBrokenTests() {
{R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"},
{R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"},
{R"(^\/sub.*dtype=tf\.int64)", "119126484"},
{R"(^\/div.*dtype=tf\.int64)", "119126484"},
{R"(^\/mul.*dtype=tf\.int64)", "119126484"},
{R"(^\/add.*dtype=tf\.int64)", "119126484"},


@ -61,7 +61,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kSub, 1}, "1.6.0"},
{{OperatorType::kSub, 2}, "1.14.0"},
{{OperatorType::kSub, 3}, "1.15.0"},
{{OperatorType::kSub, 4}, "1.15.0"},
{{OperatorType::kSub, 4}, kPendingReleaseOpVersion},
{{OperatorType::kSub, 5}, kPendingReleaseOpVersion},
{{OperatorType::kDiv, 1}, "1.6.0"},
{{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},


@ -440,76 +440,6 @@ typedef struct TfLiteTensor {
// `dims_signature` contains [1, -1, -1, 3]).
const TfLiteIntArray* dims_signature;
} TfLiteTensor;
#else
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// NOTE: This flag is opt-in only at compile time.
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
#endif // TF_LITE_STATIC_MEMORY
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not
@ -547,6 +477,112 @@ typedef struct TfLiteNode {
// WARNING: This is an experimental interface that is subject to change.
struct TfLiteDelegate* delegate;
} TfLiteNode;
#else
// NOTE: This flag is opt-in only at compile time.
//
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// This struct does not use:
// - allocation
// - buffer_handle
// - data_is_stale
// - delegate
// - dims_signature
// - name
// - sparsity
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
// Specific reduced TfLiteNode struct for TF Micro runtime. This struct contains
// only the minimum fields required to represent a node.
//
// This struct does not use:
// - delegate
// - intermediates
// - temporaries
typedef struct TfLiteNode {
// Inputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* inputs;
// Outputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* outputs;
// Opaque data provided by the node implementer through `Registration.init`.
void* user_data;
// Opaque data provided to the node if the node is a builtin. This is usually
// a structure defined in builtin_op_data.h
void* builtin_data;
// Custom initial data. This is the opaque data provided in the flatbuffer.
// WARNING: This is an experimental interface that is subject to change.
const void* custom_initial_data;
int custom_initial_data_size;
} TfLiteNode;
#endif // TF_LITE_STATIC_MEMORY
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// WARNING: This is an experimental interface that is subject to change.
//
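Both the full TfLiteTensor and the reduced micro variant keep the same invariant spelled out in the comments above: `bytes` equals the element size times the product of `dims`. A small stand-alone illustration of that arithmetic, with helper names that are mine rather than TFLite's:

    // Illustrative helpers only; not part of the TFLite C API.
    #include <cassert>
    #include <cstddef>

    static size_t ElementCount(const int* dims, int num_dims) {
      size_t count = 1;
      for (int i = 0; i < num_dims; ++i) count *= static_cast<size_t>(dims[i]);
      return count;
    }

    static size_t TensorBytes(size_t element_size, const int* dims, int num_dims) {
      return element_size * ElementCount(dims, num_dims);
    }

    int main() {
      // A kTfLiteFloat32 tensor with dims = {3, 2}:
      const int dims[] = {3, 2};
      assert(TensorBytes(sizeof(float), dims, 2) == 24);  // 4 * 3 * 2
      return 0;
    }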


@ -466,12 +466,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
op_sig.output_types.at(0) == TensorType_INT16) {
if (op_sig.options.addsub.pot_scale_int16) {
return 5;
} else {
return 4;
}
}
}
if (op_sig.options.addsub.need_broadcast &&
op_sig.options.addsub.num_dims > 4) {
if (!op_sig.input_types.empty() &&
op_sig.input_types.at(0) == TensorType_INT64) {
return 4;
}
if (op_sig.options.broadcast.need_broadcast &&
op_sig.options.broadcast.num_dims > 4) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {


@ -296,6 +296,14 @@ TEST(OpVersionTest, VersioningSubTest) {
SimpleVersioningTest(BuiltinOperator_SUB);
}
TEST(OpVersionTest, VersioningSub4Test) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_SUB,
.input_types = std::vector<TensorType>{TensorType_INT64},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
}
void SimpleMulVersioningTest(TensorType data_type, float multiplier,
int version) {
OpSignature fake_op_sig = {


@ -79,6 +79,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_SUB, 1}, "1.6.0"},
{{BuiltinOperator_SUB, 2}, "1.14.0"},
{{BuiltinOperator_SUB, 3}, kPendingReleaseVersion},
{{BuiltinOperator_SUB, 4}, kPendingReleaseVersion},
{{BuiltinOperator_DENSIFY, 1}, "2.2.0"},
{{BuiltinOperator_DIV, 1}, "1.6.0"},
{{BuiltinOperator_DIV, 2}, kPendingReleaseVersion},


@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 6, 23)
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 6, 24)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None


@ -1116,6 +1116,8 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients_impl",
"//tensorflow/python:init_ops",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python:math_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
@ -1124,7 +1126,6 @@ py_library(
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras/layers",
"//third_party/py/numpy",
],
)


@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import tempfile
@ -38,11 +39,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.layers import core
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@ -114,12 +116,19 @@ def _events_from_logdir(test_case, logdir):
class DistributionTestBase(test.TestCase):
"""Some tests that should work with any DistributionStrategy."""
def _create_variable_like_keras_dense_layer(self, name, shape, dtype):
initializer = functools.partial(
init_ops_v2.GlorotUniform(), shape, dtype=dtype)
return variables.Variable(
initial_value=initializer, name=name, trainable=True)
def _test_minimize_loss_eager(self, d):
with d.scope():
l = core.Dense(1, use_bias=False)
kernel = self._create_variable_like_keras_dense_layer(
name="kernel", shape=(1, 1), dtype=dtypes.float32)
def loss(x):
y = array_ops.reshape(l(x), []) - array_ops.identity(1.)
y = array_ops.reshape(
gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
return y * y
# TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
# common `implicit_grad` function and put it in DistributionStrategy.
@ -173,10 +182,12 @@ class DistributionTestBase(test.TestCase):
ops.Graph().as_default(), \
self.cached_session(config=config) as sess, \
d.scope():
l = core.Dense(1, use_bias=False)
kernel = self._create_variable_like_keras_dense_layer(
name="kernel", shape=(1, 1), dtype=dtypes.float32)
def loss(x):
y = array_ops.reshape(l(x), []) - array_ops.identity(1.)
y = array_ops.reshape(
gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
return y * y
grad_fn = backprop.implicit_grad(loss)


@ -890,13 +890,13 @@ class Context(object):
@property
def executor(self):
ensure_initialized()
self.ensure_initialized()
return executor.Executor(
pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
@executor.setter
def executor(self, e):
ensure_initialized()
self.ensure_initialized()
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
@property


@ -1192,7 +1192,7 @@ class _TapeGradientFunctions(object):
def _wrap_backward_function(self, forward_graph, backward, outputs):
"""Create a backward function given `outputs` from the forward function."""
capture_mapping = dict(
zip([ops.tensor_id(t) for t in forward_graph.outputs], outputs))
zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
remapped_captures = [
capture_mapping.get(ops.tensor_id(capture), capture)
for capture in backward.captured_inputs
@ -1489,9 +1489,8 @@ class ConcreteFunction(object):
self._captured_closures = self._func_graph.deferred_external_captures
structured_outputs = self._func_graph.structured_outputs
self._ndarrays_list = (
isinstance(structured_outputs, (list, tuple)) and
structured_outputs and
all([isinstance(o, np_arrays.ndarray) for o in structured_outputs]))
isinstance(structured_outputs, (list, tuple)) and structured_outputs and
all(isinstance(o, np_arrays.ndarray) for o in structured_outputs))
self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)
# function_spec defines the structured signature.
@ -2199,6 +2198,14 @@ class ConcreteFunction(object):
assert self._function_spec is not None
arg_specs, kwarg_specs = self.structured_input_signature
arg_names = list(self._function_spec.arg_names)
# If an explicit input_signature is provided to @tf.function, then any
# arguments with defaults that are not covered by that explicit signature
# are simply dropped from the signature.
# TODO(b/159639913) Look into whether dropping arguments with default values
# from the signature is the right thing to do.
arg_names = arg_names[:len(arg_specs)]
if default_values:
for i in range(len(arg_names)):
if not _contains_type_spec(arg_specs[i]):
@ -2248,6 +2255,14 @@ class ConcreteFunction(object):
lines = [self._structured_signature_summary(default_values=True)]
arg_specs, kwarg_specs = self.structured_input_signature
names = list(self._function_spec.arg_names)
# If an explicit input_signature is provided to @tf.function, then any
# arguments with defaults that are not covered by that explicit signature
# are simply dropped from the signature.
# TODO(b/159639913) Look into whether dropping arguments with default values
# from the signature is the right thing to do.
names = names[:len(arg_specs)]
names.extend(sorted(kwarg_specs))
specs = list(arg_specs) + list(kwarg_specs.values())
# note: we can skip bound args, since we already displayed their bound
@ -2855,7 +2870,6 @@ class Function(object):
graph_function, _, _ = self._maybe_define_function(args, kwargs)
return graph_function
# XX TODO: make sure we fix up this path as well!?
def _get_concrete_function_internal(self, *args, **kwargs):
"""Bypasses error checking when getting a graph function."""
graph_function = self._get_concrete_function_internal_garbage_collected(

Some files were not shown because too many files have changed in this diff.