Merge branch 'master' into addsub_16x8

commit 1bcc3bc41c
Elena Zhelezina, 2020-06-24 19:12:39 +01:00 (committed by GitHub)
110 changed files with 2320 additions and 1616 deletions


@@ -54,7 +54,7 @@ Status ProcessInputs(
     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
   input_tensors->reserve(ninputs);
   for (int i = 0; i < ninputs; ++i) {
-    Node* node = &inputs[i].oper->node;
+    Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
     int idx = inputs[i].index;
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
     TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
   output_tensors->reserve(noutputs);
   for (int i = 0; i < noutputs; ++i) {
-    Node* node = &outputs[i].oper->node;
+    Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
     int idx = outputs[i].index;
     TF_RETURN_WITH_CONTEXT_IF_ERROR(
         fn_body->graph.IsValidOutputTensor(node, idx),


@@ -19,6 +19,7 @@ tf_cc_shared_object(
 cc_library(
     name = "gcs_filesystem_impl",
     srcs = ["gcs_filesystem.cc"],
+    hdrs = ["gcs_filesystem.h"],
     copts = select({
         "//conditions:default": [],
         "//tensorflow:windows": get_win_copts(),
@@ -55,14 +56,9 @@ tf_cc_test(
         "notap",
     ],
     deps = [
-        ":gcs_helper",
+        ":gcs_filesystem_impl",
-        "//tensorflow/c:env",
-        "//tensorflow/c:tf_status",
         "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/experimental/filesystem:filesystem_interface",
         "//tensorflow/core/platform:stacktrace_handler",
         "//tensorflow/core/platform:test",
-        "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
-        "@com_google_absl//absl/strings",
     ],
 )


@@ -12,26 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
+#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
 #include <stdlib.h>
 #include <string.h>
-#include <fstream>
-#include "absl/strings/string_view.h"
 #include "google/cloud/storage/client.h"
 #include "tensorflow/c/env.h"
-#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
 #include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
 #include "tensorflow/c/tf_status.h"
-#ifdef TF_GCS_FILESYSTEM_TEST
-// For testing purpose, we expose some functions.
-#define TF_STATIC
-#else
-// Otherwise, we don't expose any symbol.
-#define TF_STATIC static
-#endif
 // Implementation of a filesystem for GCS environments.
 // This filesystem will support `gs://` URI schemes.
 namespace gcs = google::cloud::storage;
@@ -48,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
 static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
 static void plugin_memory_free(void* ptr) { free(ptr); }
-static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
-                         char** bucket, char** object, TF_Status* status) {
+void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
+                  char** object, TF_Status* status) {
   size_t scheme_end = fname.find("://") + 2;
   if (fname.substr(0, scheme_end + 1) != "gs://") {
     TF_SetStatus(status, TF_INVALID_ARGUMENT,
@@ -130,7 +120,7 @@ namespace tf_read_only_memory_region {
 namespace tf_gcs_filesystem {
 // TODO(vnvo2409): Add lazy-loading and customizing parameters.
-TF_STATIC void Init(TF_Filesystem* filesystem, TF_Status* status) {
+void Init(TF_Filesystem* filesystem, TF_Status* status) {
   google::cloud::StatusOr<gcs::Client> client =
       gcs::Client::CreateDefaultClient();
   if (!client) {
@@ -143,13 +133,13 @@ TF_STATIC void Init(TF_Filesystem* filesystem, TF_Status* status) {
   TF_SetStatus(status, TF_OK, "");
 }
-static void Cleanup(TF_Filesystem* filesystem) {
+void Cleanup(TF_Filesystem* filesystem) {
   plugin_memory_free(filesystem->plugin_filesystem);
 }
 // TODO(vnvo2409): Implement later
-static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
-                            TF_WritableFile* file, TF_Status* status) {
+void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
+                     TF_WritableFile* file, TF_Status* status) {
   char* bucket;
   char* object;
@@ -166,7 +156,7 @@ static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
   TF_SetStatus(status, TF_OK, "");
 }
-static void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
-                              TF_WritableFile* file, TF_Status* status) {
+void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
+                       TF_WritableFile* file, TF_Status* status) {
   char* bucket;
   char* object;


@@ -0,0 +1,35 @@
+/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
+
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+==============================================================================*/
+#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
+#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
+
+#include "absl/strings/string_view.h"
+#include "google/cloud/storage/client.h"
+#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+#include "tensorflow/c/tf_status.h"
+
+void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
+                  char** object, TF_Status* status);
+
+namespace tf_gcs_filesystem {
+void Init(TF_Filesystem* filesystem, TF_Status* status);
+void Cleanup(TF_Filesystem* filesystem);
+void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
+                     TF_WritableFile* file, TF_Status* status);
+void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
+                       TF_WritableFile* file, TF_Status* status);
+}  // namespace tf_gcs_filesystem
+
+#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
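As an aside (not part of the commit): the header above is what lets the test target call straight into the GCS plugin instead of forward-declaring its symbols. A minimal sketch of a caller, assuming only the declarations above plus the TF_Status C API; the stack-allocated TF_Filesystem and the literal "gs://my-bucket/..." path are illustrative assumptions, not code from the diff:

#include <cstdlib>

#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status.h"

int main() {
  TF_Status* status = TF_NewStatus();
  TF_Filesystem filesystem;

  // Initialize the GCS client; on failure `status` carries the error.
  tf_gcs_filesystem::Init(&filesystem, status);
  bool ok = TF_GetCode(status) == TF_OK;

  // Split "gs://bucket/object" into heap-allocated bucket/object strings.
  char* bucket = nullptr;
  char* object = nullptr;
  if (ok)
    ParseGCSPath("gs://my-bucket/path/to/object", /*object_empty_ok=*/false,
                 &bucket, &object, status);

  free(bucket);
  free(object);
  if (ok) tf_gcs_filesystem::Cleanup(&filesystem);
  TF_DeleteStatus(status);
  return 0;
}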


@@ -12,18 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License.
 ==============================================================================*/
-#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
+#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
 #include "tensorflow/c/tf_status_helper.h"
 #include "tensorflow/core/platform/stacktrace_handler.h"
 #include "tensorflow/core/platform/test.h"
 #define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))
-// Forward declaration
-namespace tf_gcs_filesystem {
-void Init(TF_Filesystem* filesystem, TF_Status* status);
-}
 namespace tensorflow {
 namespace {
@@ -38,7 +34,7 @@ class GCSFilesystemTest : public ::testing::Test {
   }
   void TearDown() override {
     TF_DeleteStatus(status_);
-    // TODO(vnvo2409): Add filesystem cleanup
+    tf_gcs_filesystem::Cleanup(filesystem_);
     delete filesystem_;
   }


@@ -138,6 +138,11 @@ bool IsI32Type(Type element_type) {
   return element_type.isInteger(32) && !element_type.isUnsignedInteger();
 }
+// Return true when the given element_type is I64.
+bool IsI64Type(Type element_type) {
+  return element_type.isInteger(64) && !element_type.isUnsignedInteger();
+}
+
 // Return true if the given Add operation has the CPU kernel supported shapes.
 bool VerifyAddOpShapeConstraints(AddOp op) {
   auto element_type = getElementTypeOrSelf(op.output().getType());
@@ -174,7 +179,8 @@ bool VerifySubOpShapeConstraints(SubOp op) {
   // Allows F32, QUI8, and QI16 outputs when the operands have valid shapes,
   // which are broadcastable shapes up to five dimension or have same shapes.
   if (element_type.isF32() || IsI32Type(element_type) ||
-      IsQUI8Type(element_type) || IsQI16Type(element_type)) {
+      IsI64Type(element_type) || IsQUI8Type(element_type) ||
+      IsQI16Type(element_type)) {
     return VerifyOperandsHaveSameShapesOrBroadcastableShape(
         /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
         /*max_bcast_rank=*/5);


@@ -2864,11 +2864,11 @@ def TFL_SubOp : TFL_Op<"sub", [
   }];
   let arguments = (
-    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
-    TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
+    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
+    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
     TFL_AFAttr:$fused_activation_function);
-  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
+  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);
   let hasFolder = 1;


@@ -9,6 +9,15 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
 // CHECK: return
 }
+
+func @sub(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>) -> tensor<1xi64> {
+  %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
+  return %0: tensor<1xi64>
+
+// CHECK-LABEL: sub
+// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi64>
+// CHECK: return
+}
+
 // CHECK-LABEL: testAddHighDimsHaveSameShape
 func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
   // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}


@@ -269,6 +269,14 @@ func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
   return %0#0 : tensor<? x i32>
 }
+
+// CHECK-LABEL: testSubInt64
+func @testSubInt64(tensor<? x i64>, tensor<? x i64>) -> tensor<? x i64> {
+^bb0(%arg0: tensor<? x i64>, %arg1: tensor<? x i64>):
+  // CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
+  %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i64>
+  return %0#0 : tensor<? x i64>
+}
+
 // CHECK-LABEL: testMul
 func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
 ^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):


@@ -346,6 +346,7 @@ func @replication(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<f32>) ->
 // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
 // CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
 // CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
+// CHECK-NOT: _replicated_input_indices
 // CHECK-SAME: n = 2 : i32
 // CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( {
 // CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
@@ -382,6 +383,46 @@ func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<
 // CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
 // CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}}
 // CHECK-DAG: [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: _replicated_input_indices = [0, 1, 2, -1, -1, -1]
+
+// Test TPUReplicatedInputs with non contiguous `index` attributes.
+// CHECK-LABEL: func @non_contigous_indices
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>)
+func @non_contigous_indices(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
+  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 8 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
+  %1 = "tf.TPUReplicatedInput"(%arg1, %arg1) : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  "tf.opB"(%1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
+  %2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  "tf.opC"(%2) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
+
+// CHECK: tf_device.replicate
+// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: _replicated_input_indices = [2, 8, -1]
+
+// Test that the `is_mirrored_variable` attribute is preserved in the
+// tf_device.replicate op.
+// CHECK-LABEL: func @mirrored_variables
+// CHECK-SAME: (%[[ARG_0:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_2:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_3:.*]]: tensor<!tf.resource<tensor<32xf32>>>)
+func @mirrored_variables(%arg0: tensor<!tf.resource<tensor<32xf32>>>, %arg1: tensor<!tf.resource<tensor<32xf32>>>, %arg2: tensor<!tf.resource<tensor<32xf32>>>, %arg3: tensor<!tf.resource<tensor<32xf32>>>) {
+  %0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
+  %1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
+  "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device"} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> ()
+  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
+  return
+}
+
+// CHECK: tf_device.replicate
+// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}}
+// CHECK-SAME: _mirrored_variable_indices = [1]
+// CHECK-SAME: _replicated_input_indices = [0, 1]

 // -----
@@ -407,8 +448,10 @@ func @bad_num_replicas() {
   return
 }

 // -----

 // Test that functions without TPUReplicateMetadata op are skipped without
 // error
 // CHECK-LABEL: func @missing_metadata_op
@@ -483,22 +526,9 @@ func @leftover_replicated_output(%arg0: tensor<i1>) {

 // -----

-// Test bad TPUReplicatedInput positive `index` attribute.
-func @bad_positive_index_input(%arg0: tensor<i1>) {
-  // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got 1}}
-  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
-  "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
-  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
-  return
-}
-
-// -----
-
 // Test bad TPUReplicatedInput negative `index` attribute.
 func @bad_negative_index_input(%arg0: tensor<i1>) {
-  // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got -2}}
+  // expected-error@+1 {{'tf.TPUReplicatedInput' op requires index to be at least -1, but got -2}}
   %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
   "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
   "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
@@ -509,33 +539,12 @@ func @bad_negative_index_input(%arg0: tensor<i1>) {

 // -----

-// Test TPUReplicatedInput with conflicting `index` attribute. This will result
-// in gaps in the TPUReplicatedInput ordering.
+// Test TPUReplicatedInput with conflicting `index` attribute.
 func @input_index_gaps(%arg0: tensor<i1>) {
-  // expected-error@+1 {{failed to sort 'tf.TPUReplicatedInput' ops, gap(s) found in indices}}
   %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
+  // expected-error@+1 {{'tf.TPUReplicatedInput' op requires indices to be unique, but found multiple 'tf.TPUReplicatedInput' ops with index 1}}
   %1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
   "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>, tensor<i1>) -> ()
   "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
   return
 }
-
-// -----
-
-// Test that the `is_mirrored_variable` attribute is preserved in the
-// tf_device.replicate op.
-// CHECK-LABEL: func @mirrored_variables
-// CHECK-SAME: (%[[ARG_0:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_2:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_3:.*]]: tensor<!tf.resource<tensor<32xf32>>>)
-func @mirrored_variables(%arg0: tensor<!tf.resource<tensor<32xf32>>>, %arg1: tensor<!tf.resource<tensor<32xf32>>>, %arg2: tensor<!tf.resource<tensor<32xf32>>>, %arg3: tensor<!tf.resource<tensor<32xf32>>>) {
-  %0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
-  %1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
-  "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device"} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> ()
-  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
-  return
-}
-// CHECK: tf_device.replicate
-// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}}
-// CHECK-SAME: _mirrored_variable_indices = [1]


@@ -9,7 +9,7 @@
 //   padding_arg_index: 1
 // CHECK-LABEL: func @single_arg_single_shape
 func @single_arg_single_shape(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
     "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func0, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -36,7 +36,7 @@ func @func0(%arg0: tensor<i1>, %arg1: tensor<i1>) {
 //   padding_arg_index: 2
 // CHECK-LABEL: func @single_arg_multiple_shapes
 func @single_arg_multiple_shapes(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>) {_replicated_input_indices = [0, 1, 2], n = 2 : i32} {
     "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2) {func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -68,7 +68,7 @@ func @func1(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
 //   padding_arg_index: 3
 // CHECK-LABEL: func @multiple_args
 func @multiple_args(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>, [%arg0, %arg0] as %ri_4: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>, [%arg0, %arg0] as %ri_4: tensor<i1>) {_replicated_input_indices = [0, 1, 2, 3, 4], n = 2 : i32} {
     "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -89,7 +89,7 @@ func @func2(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tens
 //   padding_arg_index: 1
 // CHECK-LABEL: func @remap_indices
 func @remap_indices(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
     "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func3, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -124,7 +124,7 @@ func @func4(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
 // Test encapsulated function is not modified when there are no padding maps.
 // CHECK-LABEL: func @no_padding_map
 func @no_padding_map(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
     "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func5} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -140,7 +140,7 @@ func @func5(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
 // Test encapsulated function is not modified when padding maps is empty.
 // CHECK-LABEL: func @empty_padding_map
 func @empty_padding_map(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
     "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func6, padding_map = []} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -161,7 +161,7 @@ func @func6(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
 //   padding_arg_index: 1
 // CHECK-LABEL: func @unused_padding_map
 func @unused_padding_map(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
     "tf_device.cluster_func"(%ri_1) {func = @func7, padding_map = ["\10\02\18\01"]} : (tensor<i1>) -> ()
     tf_device.return
   }
@@ -187,7 +187,7 @@ func @func7(%arg0: tensor<i1>) {
 //   shape_index: 2
 //   padding_arg_index: 3
 func @missing_padding_arg(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>) {_replicated_input_indices = [0, 1, 2, 3], n = 2 : i32} {
     // expected-warning@+1 {{bad 'padding_map' attribute at index 0, unused padding_arg_index 1}}
     "tf_device.cluster_func"(%ri_0, %ri_2, %ri_3) {func = @func8, padding_map = ["\10\02\18\01", "\08\02\10\02\18\03"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
     tf_device.return
@@ -201,11 +201,55 @@ func @func8(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
   return
 }
+
+// Test tf_device.replicate with missing _replicated_input_indices does no
+// transformation.
+//
+// Padding map "\10\02\18\01":
+//   arg_index: 0
+//   shape_index: 2
+//   padding_arg_index: 1
+// CHECK-LABEL: func @missing_replicated_input_indices
+func @missing_replicated_input_indices(%arg0: tensor<i1>) {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func9, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
+    tf_device.return
+  }
+  return
+}
+
+// CHECK-LABEL: func @func9
+// CHECK-NOT: xla_hlo.padding_map
+func @func9(%arg0: tensor<i1>, %arg1: tensor<i1>) {
+  return
+}
+
+// Test single argument with padding map lifted to associated encapsulated
+// function.
+//
+// Padding map "\08\08\10\06\18\02"
+//   arg_index: 8
+//   shape_index: 6
+//   padding_arg_index: 2
+// CHECK-LABEL: func @non_contigous_indices
+func @non_contigous_indices(%arg0: tensor<i1>) {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [2, 8], n = 2 : i32} {
+    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func10, padding_map = ["\08\08\10\06\18\02"]} : (tensor<i1>, tensor<i1>) -> ()
+    tf_device.return
+  }
+  return
+}
+
+// CHECK-LABEL: func @func10
+// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}})
+func @func10(%arg0: tensor<i1>, %arg1: tensor<i1>) {
+  return
+}

 // -----

 // Test bad padding map attribute (not an array).
 func @bad_padding_map() {
-  tf_device.replicate {n = 2 : i32} {
+  tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
     // expected-error@+1 {{'tf_device.cluster_func' op requires 'padding_map' array attribute}}
     "tf_device.cluster_func"() {func = @_func, padding_map = 0 : i32} : () -> ()
     tf_device.return
@@ -221,7 +265,7 @@ func @_func() {
 // Test bad padding map attribute (element in array is not a string).
 func @bad_padding_map_element() {
-  tf_device.replicate {n = 2 : i32} {
+  tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
     // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, not a string}}
     "tf_device.cluster_func"() {func = @_func, padding_map = [0 : i32]} : () -> ()
     tf_device.return
@@ -237,7 +281,7 @@ func @_func() {
 // Test unparsable padding map.
 func @bad_padding_map_proto() {
-  tf_device.replicate {n = 2 : i32} {
+  tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
     // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}}
     "tf_device.cluster_func"() {func = @_func, padding_map = ["z"]} : () -> ()
     tf_device.return
@@ -258,8 +302,8 @@ func @_func() {
 //   shape_index: 2
 //   padding_arg_index: 1
 func @negative_arg_index(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
-    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}}
+    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be nonnegative, but got -1}}
     "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -272,27 +316,6 @@ func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {

 // -----

-// Test out of bound arg index.
-//
-// Padding map "\08\02\10\02\18\01":
-//   arg_index: 2
-//   shape_index: 2
-//   padding_arg_index: 1
-func @bad_arg_index(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
-    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}}
-    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
-    tf_device.return
-  }
-  return
-}
-
-func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
-  return
-}
-
-// -----
-
 // Test negative padding arg index.
 //
 // Padding map "\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01":
@@ -300,8 +323,8 @@ func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
 //   shape_index: 2
 //   padding_arg_index: -1
 func @negative_padding_arg_index(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
+  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
-    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}}
+    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be nonnegative, but got -1}}
     "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor<i1>, tensor<i1>) -> ()
     tf_device.return
   }
@@ -311,24 +334,3 @@ func @negative_padding_arg_index(%arg0: tensor<i1>) {
 func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
   return
 }
-
-// -----
-
-// Test out of bound padding arg index.
-//
-// Padding map "\08\01\10\02\18\02":
-//   arg_index: 1
-//   shape_index: 2
-//   padding_arg_index: 2
-func @bad_padding_arg_index(%arg0: tensor<i1>) {
-  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
-    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got 2}}
-    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor<i1>, tensor<i1>) -> ()
-    tf_device.return
-  }
-  return
-}
-
-func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
-  return
-}


@@ -2,18 +2,7 @@
 // Tests that missing `_xla_outside_compilation` attribute value results in an error.
-func @missing_outside_compilation_attribute() -> () {
+module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
-  "tf_device.cluster"() ( {
-    "tf.A"() : () -> ()
-    // expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
-    "tf.B"() {_xla_outside_compilation = ""} : () -> ()
-    tf_device.return
-  }) {cluster_attr = "cluster_attr"} : () -> ()
-  return
-}
-
-// -----
-
 // Tests that TPU cluster with no outside compilation does not generate parallel_execute.
 // CHECK-LABEL: func @no_outside_compilation
@@ -22,7 +11,7 @@ func @no_outside_compilation() -> tensor<?xi32> {
     %1 = "tf.A"() : () -> tensor<?xi32>
     %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
-  }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
   return %0 : tensor<?xi32>
 }
@@ -36,15 +25,17 @@ func @nodep_single_outside_compilation() -> () {
   // CHECK-NEXT: "tf_device.launch"
   // CHECK-NEXT: "tf.B"
   // CHECK-NOT: _xla_outside_compilation
+  // CHECK-NEXT: tf_device.return
+  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
   // CHECK: "tf_device.cluster"
   // CHECK-NEXT: "tf.A"
-  // CHECK: cluster_attr = "cluster_attr"
+  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
   "tf_device.cluster"() ( {
     "tf.A"() : () -> ()
     "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.C"() : () -> ()
     tf_device.return
-  }) {cluster_attr = "cluster_attr"} : () -> ()
+  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
   return
 }
@@ -61,7 +52,7 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
   // CHECK: "tf_device.cluster"
   // CHECK-NEXT: "tf.A"
   // CHECK-NEXT: "tf.E"
-  // CHECK: cluster_attr = "cluster_attr"
+  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
   "tf_device.cluster"() ( {
     "tf.A"() : () -> ()
     "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
@@ -69,7 +60,7 @@ func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
     "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
     "tf.E"() : () -> ()
     tf_device.return
-  }) {cluster_attr = "cluster_attr"} : () -> ()
+  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
   return
 }
@@ -87,7 +78,7 @@ func @nodep_multiple_outside_compilation() -> () {
     "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
     "tf.E"() : () -> ()
     tf_device.return
-  }) {cluster_attr = "cluster_attr"} : () -> ()
+  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
   return
 }
@@ -99,6 +90,9 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tens
   // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
   // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
   // CHECK-NEXT: "tf_device.launch"
+  // CHECK-NEXT: "tf.B"
+  // CHECK-NEXT: tf_device.return
+  // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
   // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
   // CHECK: tf_device.return
   // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
@@ -109,7 +103,7 @@ func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tens
       "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
       %3 = "tf.C"() : () -> tensor<?xi32>
       tf_device.return %3 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -134,7 +128,7 @@ func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> te
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %5 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xf32>, tensor<?xi32>)
     tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
   }
@@ -163,7 +157,7 @@ func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -194,7 +188,7 @@ func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?x
      %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -224,7 +218,7 @@ func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
      tf_device.return %4 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -255,7 +249,7 @@ func @single_outside_compiled_input_output_single_outside_compilation(%arg0: ten
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -290,7 +284,7 @@ func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: t
      %7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
      %8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %8 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -332,7 +326,7 @@ func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?
      %6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %7 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -361,7 +355,7 @@ func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi
      "tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -399,7 +393,7 @@ func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?
      %4 = "tf.C"() : () -> tensor<?xi32>
      "tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
      tf_device.return %4 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -432,7 +426,7 @@ func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<
      "tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
      %5 = "tf.E"() : () -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
     tf_device.return %2 : tensor<?xi32>
   }
@@ -454,8 +448,9 @@ func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
      tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
-    }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xi32>, tensor<?xi32>)
+    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xi32>, tensor<?xi32>)
     tf_device.return %2#1 : tensor<?xi32>
   }
   return %1 : tensor<?xi32>
 }
+}


@@ -22,6 +22,7 @@ limitations under the License.
 // not have ops outside of the cluster that are both operands and results of the
 // cluster. Note, this currently does not handle side effecting ops yet.

+#include <algorithm>
 #include <iterator>
 #include <memory>
 #include <tuple>
@@ -29,6 +30,7 @@ limitations under the License.
 #include "llvm/ADT/ArrayRef.h"
 #include "llvm/ADT/DenseMap.h"
+#include "llvm/ADT/DenseSet.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/ADT/SetVector.h"
 #include "llvm/ADT/SmallVector.h"
@@ -59,6 +61,7 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
 constexpr char kDeviceAttr[] = "device";
 constexpr char kNameAttr[] = "name";
 constexpr char kNumReplicasAttr[] = "num_replicas";
+constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
 constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";

 constexpr char kBadTPUReplicateAttrMsg[] =
@@ -261,33 +264,42 @@ void MovePrecedingClusterUsers(tf_device::ClusterOp cluster,
 // Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
 // of -1 are always after ops with a non negative `index`, and an arbitrary
-// ordering is used as there are no dependencies on their relative ordering.
+// ordering is used as there are no dependencies on their relative ordering. If
+// there are multiple `tf.TPUReplicatedInput` ops with the same non negative
+// index or if indices are less than -1, an error will be returned.
 LogicalResult SortTPUReplicatedInputsByIndex(
     llvm::ArrayRef<Operation*> inputs,
     llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
-  const int input_size = inputs.size();
-  sorted_inputs->resize(input_size, nullptr);
-  int last_index = input_size - 1;
+  llvm::SmallDenseSet<int64_t, 8> unique_indices;
   for (Operation* input : inputs) {
     int64_t index =
-        llvm::cast<TF::TPUReplicatedInputOp>(input).index().getLimitedValue();
-    if (index >= input_size || index < -1)
-      return input->emitError() << "'" << input->getName().getStringRef()
-             << "' index is not in range [-1, " << input_size
-             << "), got " << index;
-    if (index == -1)
-      (*sorted_inputs)[last_index--] = input;
-    else
-      (*sorted_inputs)[index] = input;
+        llvm::cast<TF::TPUReplicatedInputOp>(input).index().getSExtValue();
+    if (index < -1)
+      return input->emitOpError()
+             << "requires index to be at least -1, but got " << index;
+    if (index == -1) continue;
+    if (!unique_indices.insert(index).second)
+      return input->emitOpError()
+             << "requires indices to be unique, but found multiple '"
+             << input->getName() << "' ops with index " << index;
   }

-  if (llvm::any_of(*sorted_inputs, [](Operation* op) { return op == nullptr; }))
-    return inputs.front()->emitError()
-           << "failed to sort '" << inputs.front()->getName().getStringRef()
-           << "' ops, gap(s) found in indices";
+  // Sort all TPUReplicatedInputs by `index` attribute to have
+  // TPUReplicatedInputs with indices be added to the `tf_device.replicate` op
+  // deterministically. If `index` attribute is -1, instead move them to the
+  // end.
+  sorted_inputs->assign(inputs.begin(), inputs.end());
+  std::stable_sort(
+      sorted_inputs->begin(), sorted_inputs->end(),
+      [](Operation* l, Operation* r) {
+        int64_t l_index =
+            llvm::cast<TF::TPUReplicatedInputOp>(l).index().getSExtValue();
+        int64_t r_index =
+            llvm::cast<TF::TPUReplicatedInputOp>(r).index().getSExtValue();
+        if (l_index == -1 && r_index != -1) return false;
+        if (r_index == -1 && l_index != -1) return true;
+        return l_index < r_index;
+      });

   return success();
 }
@@ -315,6 +327,11 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
           unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
     return failure();

+  // Index attribute value stored on TPUReplicatedInput op. These will be used
+  // later for dynamic padder.
+  llvm::SmallVector<int64_t, 8> replicated_input_indices;
+  bool has_replicated_input_index = false;
+
   // Indices of the replicate op's arguments that are mirrored variables.
   llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
@@ -330,7 +347,14 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
     replicated_inputs.push_back(
         {input->getOperands(), input->getOperand(0).getType()});
-    if (llvm::cast<TF::TPUReplicatedInputOp>(input).is_mirrored_variable())
+
+    auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
+    int64_t tpu_replicated_input_index =
+        tpu_replicated_input.index().getSExtValue();
+    replicated_input_indices.push_back(tpu_replicated_input_index);
+    if (tpu_replicated_input_index != -1) has_replicated_input_index = true;
+
+    if (tpu_replicated_input.is_mirrored_variable())
       mirrored_variable_indices.push_back(pos_and_input.index());
   }
@@ -340,6 +364,10 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
       cluster.getLoc(), num_replicas,
       llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
       replicated_inputs, cluster.getResultTypes());
+  if (has_replicated_input_index)
+    replicate_op.setAttr(kReplicatedInputIndicesAttr,
+                         builder.getI64ArrayAttr(replicated_input_indices));
+
   if (!mirrored_variable_indices.empty())
     replicate_op.setAttr(kMirroredVariableIndicesAttr,
                          builder.getI64ArrayAttr(mirrored_variable_indices));

View File

@ -40,6 +40,7 @@ limitations under the License.
namespace mlir { namespace mlir {
namespace TFTPU { namespace TFTPU {
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
constexpr char kPaddingMapAttr[] = "padding_map"; constexpr char kPaddingMapAttr[] = "padding_map";
// This pass remaps and assigns padding maps to an encapsulated function's // This pass remaps and assigns padding maps to an encapsulated function's
@ -56,14 +57,23 @@ struct TPUDynamicPaddingMapper
// Creates a mapping from replicated input index (in `tf_device.replicate` op) // Creates a mapping from replicated input index (in `tf_device.replicate` op)
// to `tf_device.cluster_func` operand index. // to `tf_device.cluster_func` operand index.
llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices( llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate) { tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate,
ArrayAttr replicated_input_indices_attr) {
Block* replicate_block = &replicate.GetBody(); Block* replicate_block = &replicate.GetBody();
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices; llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>()) if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>()) {
if (block_arg.getOwner() == replicate_block) if (block_arg.getOwner() == replicate_block) {
remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index(); int64_t replicated_input_index =
replicated_input_indices_attr[block_arg.getArgNumber()]
.cast<IntegerAttr>()
.getInt();
if (replicated_input_index != -1)
remapped_indices[replicated_input_index] = operand_and_idx.index();
}
}
}
return remapped_indices; return remapped_indices;
} }
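A small sketch of the mapping this helper now builds, using made-up values rather than the pass's real data structures: each cluster_func operand that is a replicate block argument is keyed by the `_replicated_input_indices` entry for that argument, and `-1` entries are skipped.

```python
# Hypothetical data: replicated_input_indices[arg_number] is the index recorded
# on the tf_device.replicate op; operands lists which cluster_func operands are
# block arguments of that replicate op.
replicated_input_indices = [0, -1, 2]
operands = [("block_arg", 0), ("block_arg", 2), ("constant", None)]

remapped = {}
for operand_idx, (kind, arg_number) in enumerate(operands):
    if kind != "block_arg":
        continue
    replicated_index = replicated_input_indices[arg_number]
    if replicated_index != -1:
        remapped[replicated_index] = operand_idx

print(remapped)  # {0: 0, 2: 1}
```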
@ -73,16 +83,15 @@ llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
// indices. An error will be returned if an index is not found or parsing // indices. An error will be returned if an index is not found or parsing
// failed. // failed.
LogicalResult GetRemappedPaddings( LogicalResult GetRemappedPaddings(
tf_device::ClusterFuncOp cluster_func, int num_replicated_args, tf_device::ClusterFuncOp cluster_func,
const llvm::SmallDenseMap<int32_t, int32_t>& remapped_indices, const llvm::SmallDenseMap<int32_t, int32_t>& remapped_indices,
llvm::SmallVectorImpl<tensorflow::tpu::PaddingMap>* remapped_paddings) { llvm::SmallVectorImpl<tensorflow::tpu::PaddingMap>* remapped_paddings) {
auto bad_index_msg = [num_replicated_args](int32_t index, auto bad_index_msg = [](int32_t index, llvm::StringRef arg_type,
llvm::StringRef arg_type,
int32_t arg_index) { int32_t arg_index) {
return llvm::formatv( return llvm::formatv(
"bad '{0}' attribute at index {1}, {2} must be in [0, {3}), got " "bad '{0}' attribute at index {1}, {2} must be nonnegative, but "
"{4}", "got {3}",
kPaddingMapAttr, index, arg_type, num_replicated_args, arg_index) kPaddingMapAttr, index, arg_type, arg_index)
.str(); .str();
}; };
@ -111,12 +120,12 @@ LogicalResult GetRemappedPaddings(
kPaddingMapAttr, idx, padding.getValue())); kPaddingMapAttr, idx, padding.getValue()));
const int32_t arg_index = padding_proto.arg_index(); const int32_t arg_index = padding_proto.arg_index();
if (arg_index >= num_replicated_args || arg_index < 0) if (arg_index < 0)
return cluster_func.emitOpError() return cluster_func.emitOpError()
<< bad_index_msg(idx, "arg_index", arg_index); << bad_index_msg(idx, "arg_index", arg_index);
const int32_t padding_arg_index = padding_proto.padding_arg_index(); const int32_t padding_arg_index = padding_proto.padding_arg_index();
if (padding_arg_index >= num_replicated_args || padding_arg_index < 0) if (padding_arg_index < 0)
return cluster_func.emitOpError() return cluster_func.emitOpError()
<< bad_index_msg(idx, "padding_arg_index", padding_arg_index); << bad_index_msg(idx, "padding_arg_index", padding_arg_index);
@ -175,17 +184,21 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>(); auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
// LaunchFunc is not replicated, there will be no padding. // LaunchFunc is not replicated, there will be no padding.
if (!replicate) return success(); if (!replicate) return success();
const int num_replicated_args = replicate.GetBody().getNumArguments();
auto func = symbol_table->lookup<FuncOp>(cluster_func.func()); auto func = symbol_table->lookup<FuncOp>(cluster_func.func());
if (!func) return success(); if (!func) return success();
auto replicated_input_indices_attr =
replicate.getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
if (!replicated_input_indices_attr) return success();
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices = llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
GetRemappedReplicatedInputIndices(cluster_func, replicate); GetRemappedReplicatedInputIndices(cluster_func, replicate,
replicated_input_indices_attr);
llvm::SmallVector<tensorflow::tpu::PaddingMap, 4> remapped_paddings; llvm::SmallVector<tensorflow::tpu::PaddingMap, 4> remapped_paddings;
if (failed(GetRemappedPaddings(cluster_func, num_replicated_args, if (failed(GetRemappedPaddings(cluster_func, remapped_indices,
remapped_indices, &remapped_paddings))) &remapped_paddings)))
return failure(); return failure();
AnnotateFunctionArgumentsWithPaddings(func, remapped_paddings); AnnotateFunctionArgumentsWithPaddings(func, remapped_paddings);

View File

@ -26,6 +26,8 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h" #include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h" #include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
namespace mlir { namespace mlir {
namespace TFTPU { namespace TFTPU {
@ -91,13 +93,14 @@ void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
// Creates a `tf_device::LaunchOp` to wrap cluster ops. // Creates a `tf_device::LaunchOp` to wrap cluster ops.
tf_device::LaunchOp CreateLaunchOpForOutsideCluster( tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
OpBuilder* builder, Operation* last_cluster_op) { OpBuilder* builder, Operation* last_cluster_op,
// TODO(b/154363171): Set the CPU device. llvm::StringRef host_device) {
// An empty string placeholder is used for the device as that will be later // An empty string placeholder is used for the device as that will be later
// populated with the device of the associated TPUReplicateMetadata op. // populated with the device of the associated TPUReplicateMetadata op.
llvm::SmallVector<Type, 8> result_types; llvm::SmallVector<Type, 8> result_types;
auto launch_op = builder->create<tf_device::LaunchOp>( auto launch_op = builder->create<tf_device::LaunchOp>(
last_cluster_op->getLoc(), builder->getStringAttr(""), result_types); last_cluster_op->getLoc(), builder->getStringAttr(host_device),
result_types);
launch_op.body().push_back(new Block); launch_op.body().push_back(new Block);
@ -253,8 +256,9 @@ void MoveOutsideCompiledOps(
// Creates a `parallel_execute` op in place of launch with 'clusters` and // Creates a `parallel_execute` op in place of launch with 'clusters` and
// 'launch` as regions. // 'launch` as regions.
void CreateParallelExecuteFromOutsideClusters( void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster,
tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) { const OutsideClusterMap& clusters,
llvm::StringRef host_device) {
OpBuilder builder(tpu_cluster); OpBuilder builder(tpu_cluster);
// Create parallel_execute regions. The original TPU cluster computation // Create parallel_execute regions. The original TPU cluster computation
// is the extra region. // is the extra region.
@ -269,8 +273,8 @@ void CreateParallelExecuteFromOutsideClusters(
Block& outside_block = Block& outside_block =
parallel_execute_op.GetRegionBlockWithIndex(cluster.index()); parallel_execute_op.GetRegionBlockWithIndex(cluster.index());
builder.setInsertionPointToEnd(&outside_block); builder.setInsertionPointToEnd(&outside_block);
tf_device::LaunchOp host_launch_op = tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster(
CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back()); &builder, cluster_ops.back(), host_device);
// Determine if there are any inputs that are provided out of cluster. // Determine if there are any inputs that are provided out of cluster.
auto external_inputs = GetExternalOperands(cluster_ops); auto external_inputs = GetExternalOperands(cluster_ops);
@ -307,8 +311,14 @@ void CreateParallelExecuteFromOutsideClusters(
} }
void TPUExtractOutsideCompilation::runOnOperation() { void TPUExtractOutsideCompilation::runOnOperation() {
// Get runtime devices information from the closest parent module.
auto module = getOperation();
mlir::TF::RuntimeDevices devices;
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
return signalPassFailure();
auto extract_result = auto extract_result =
getOperation().walk([&](tf_device::ClusterOp tpu_cluster) { module.walk([&](tf_device::ClusterOp tpu_cluster) {
OutsideClusterMap clusters; OutsideClusterMap clusters;
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(), if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
&clusters))) &clusters)))
@ -316,7 +326,11 @@ void TPUExtractOutsideCompilation::runOnOperation() {
if (clusters.empty()) return WalkResult::advance(); if (clusters.empty()) return WalkResult::advance();
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters); std::string host_device;
tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
&host_device);
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters,
host_device);
return WalkResult::advance(); return WalkResult::advance();
}); });

View File

@ -91,6 +91,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client); explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
PjRtClient* pjrt_client() const { return pjrt_client_.get(); } PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; }
const std::string& platform_name() const { const std::string& platform_name() const {
return pjrt_client_->platform_name(); return pjrt_client_->platform_name();

View File

@ -34,7 +34,7 @@ END
description: <<END description: <<END
The type of padding algorithm to use. The type of padding algorithm to use.
We specify the size-related attributes as: The size-related attributes are specified as follows:
```python ```python
ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1] ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
@ -42,5 +42,5 @@ We specify the size-related attributes as:
``` ```
END END
} }
summary: "Extract `patches` from `input` and put them in the \"depth\" output dimension. 3D extension of `extract_image_patches`." summary: "Extract `patches` from `input` and put them in the `\"depth\"` output dimension. 3D extension of `extract_image_patches`."
} }
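For context, a minimal usage sketch of this op from Python, assuming the public `tf.extract_volume_patches` wrapper; the shapes and sizes are illustrative only.

```python
import tensorflow as tf

# 5-D input: [batch, depth, rows, cols, channels].
volume = tf.random.uniform([1, 8, 8, 8, 3])

# ksizes/strides follow the [1, planes, rows, cols, 1] layout described above.
patches = tf.extract_volume_patches(
    input=volume,
    ksizes=[1, 2, 2, 2, 1],
    strides=[1, 2, 2, 2, 1],
    padding="VALID")

# Each 2x2x2 patch of 3 channels is flattened into the trailing "depth" axis.
print(patches.shape)  # (1, 4, 4, 4, 24)
```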

View File

@ -9,10 +9,9 @@ def _lookup_file(filegroup, path):
return file return file
return None return None
def _gen_kernel_image_hdr_impl(ctx): CubinInfo = provider(fields = ["cubins"])
if not ctx.attr.gpu_archs:
fail("No GPU architecture specified, use --config=cuda or similar")
def _gen_kernel_cubin_impl(ctx):
name = ctx.attr.name name = ctx.attr.name
tile_sizes = ctx.attr.tile_size.replace("x", ",") tile_sizes = ctx.attr.tile_size.replace("x", ",")
cmd_args = [] cmd_args = []
@ -22,7 +21,6 @@ def _gen_kernel_image_hdr_impl(ctx):
cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors) cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
cubins = [] cubins = []
images = []
for arch in ctx.attr.gpu_archs: for arch in ctx.attr.gpu_archs:
# TODO(b/152737872): 'compute_' should generate both SASS and PTX. # TODO(b/152737872): 'compute_' should generate both SASS and PTX.
arch = arch.replace("compute_", "sm_") arch = arch.replace("compute_", "sm_")
@ -41,13 +39,36 @@ def _gen_kernel_image_hdr_impl(ctx):
mnemonic = "compile", mnemonic = "compile",
) )
cubins.append(cubin) cubins.append(cubin)
return [CubinInfo(cubins = cubins)]
_gen_kernel_cubin_rule = rule(
implementation = _gen_kernel_cubin_impl,
attrs = {
"mlir_op": attr.label(mandatory = True, allow_single_file = True),
"tile_size": attr.string(mandatory = True),
"same_shape": attr.string(),
"unroll_factors": attr.string(),
"gpu_archs": attr.string_list(mandatory = True),
"_tool": attr.label(
executable = True,
default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
cfg = "host",
),
},
output_to_genfiles = True,
)
def _gen_kernel_image_hdr_impl(ctx):
images = []
for cubin in ctx.attr.input[CubinInfo].cubins:
arch = cubin.path.split(".")[-2]
images.append("--image=profile=%s,file=%s" % (arch, cubin.path)) images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
# Generate fatbin file from all cubins. # Generate fatbin file from all cubins.
fatbin = ctx.actions.declare_file("%s.fatbin" % name) fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
ctx.actions.run( ctx.actions.run(
outputs = [fatbin], outputs = [fatbin],
inputs = cubins, inputs = ctx.attr.input[CubinInfo].cubins,
executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"), executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
arguments = [ arguments = [
"--64", "--64",
@ -73,37 +94,31 @@ _gen_kernel_image_hdr_rule = rule(
implementation = _gen_kernel_image_hdr_impl, implementation = _gen_kernel_image_hdr_impl,
output_to_genfiles = True, output_to_genfiles = True,
attrs = { attrs = {
"mlir_op": attr.label(mandatory = True, allow_single_file = True), "input": attr.label(mandatory = True, providers = [CubinInfo]),
"tile_size": attr.string(mandatory = True),
"same_shape": attr.string(),
"unroll_factors": attr.string(),
"out": attr.output(mandatory = True), "out": attr.output(mandatory = True),
"symbol": attr.string(mandatory = True), "symbol": attr.string(mandatory = True),
"gpu_archs": attr.string_list(mandatory = True),
"_cuda_root": attr.label( "_cuda_root": attr.label(
default = Label("@local_config_cuda//cuda:cuda_root"), default = Label("@local_config_cuda//cuda:cuda_root"),
), ),
"_tool": attr.label(
executable = True,
default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
cfg = "host",
),
}, },
) )
def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None, unroll_factors = None): def _gen_kernel_image_hdr(name, mlir_op, tile_size, same_shape = None, unroll_factors = None):
"""Generates a C header with fatbin data from a Tensorflow op.""" """Generates a C header with fatbin data from a Tensorflow op."""
if cuda_gpu_architectures(): if cuda_gpu_architectures():
_gen_kernel_image_hdr_rule( _gen_kernel_cubin_rule(
name = name, name = name + "_cubin",
mlir_op = mlir_op, mlir_op = mlir_op,
tile_size = tile_size, tile_size = tile_size,
same_shape = same_shape, same_shape = same_shape,
unroll_factors = unroll_factors, unroll_factors = unroll_factors,
gpu_archs = cuda_gpu_architectures(),
)
_gen_kernel_image_hdr_rule(
name = name,
input = ":" + name + "_cubin",
out = "%s.h" % name, out = "%s.h" % name,
symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""), symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
gpu_archs = cuda_gpu_architectures(),
tags = tags,
) )
def _gen_mlir_op_impl(ctx): def _gen_mlir_op_impl(ctx):
@ -157,7 +172,6 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr
name = "{name}_{type}_kernel".format(name = name, type = type), name = "{name}_{type}_kernel".format(name = name, type = type),
mlir_op = "{name}_{type}.mlir".format(name = name, type = type), mlir_op = "{name}_{type}.mlir".format(name = name, type = type),
tile_size = tile_size, tile_size = tile_size,
tags = tags,
same_shape = same_shape, same_shape = same_shape,
unroll_factors = unroll_factors, unroll_factors = unroll_factors,
) )

View File

@ -108,7 +108,7 @@ limitations under the License.
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0 #define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0 #define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
#define TF_GRAPH_DEF_VERSION 441 // Updated: 2020/6/23 #define TF_GRAPH_DEF_VERSION 442 // Updated: 2020/6/24
// Checkpoint compatibility versions (the versions field in SavedSliceMeta). // Checkpoint compatibility versions (the versions field in SavedSliceMeta).
// //

View File

@ -28,10 +28,12 @@ cc_library(
srcs = ["tpu_compile_op_common.cc"], srcs = ["tpu_compile_op_common.cc"],
hdrs = ["tpu_compile_op_common.h"], hdrs = ["tpu_compile_op_common.h"],
deps = [ deps = [
":tpu_compile_op_options",
":tpu_compile_op_support", ":tpu_compile_op_support",
":tpu_mesh_state_interface", ":tpu_mesh_state_interface",
":tpu_program_group_interface", ":tpu_program_group_interface",
":tpu_util", ":tpu_util",
":tpu_util_c_api_hdrs",
":tpu_util_hdrs", ":tpu_util_hdrs",
"//tensorflow/compiler/jit:flags", "//tensorflow/compiler/jit:flags",
"//tensorflow/compiler/jit:shape_inference", "//tensorflow/compiler/jit:shape_inference",
@ -50,6 +52,7 @@ cc_library(
"//tensorflow/core/tpu:tpu_configuration", "//tensorflow/core/tpu:tpu_configuration",
"//tensorflow/core/tpu:tpu_defs", "//tensorflow/core/tpu:tpu_defs",
"//tensorflow/stream_executor/tpu:tpu_platform_interface", "//tensorflow/stream_executor/tpu:tpu_platform_interface",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
], ],
alwayslink = 1, alwayslink = 1,

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <string> #include <string>
#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/flags.h" #include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h" #include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h" #include "tensorflow/compiler/xla/client/client_library.h"
@ -28,8 +29,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h" #include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h" #include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h" #include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h" #include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h" #include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h" #include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h" #include "tensorflow/core/tpu/tpu_defs.h"
@ -518,5 +521,41 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
return Status::OK(); return Status::OK();
} }
void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";
std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));
CancellationToken token =
ctx->cancellation_manager()->get_cancellation_token();
const bool already_cancelled =
!ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
if (TpuCompile_ShouldTpuCompileOpIgnoreCancellation()) {
return;
}
// Sleep and exit in another thread so the cancellation manager can
// continue running callbacks.
ctx->env()->SchedClosure([ctx, done]() { ExitCountdown(ctx, done); });
});
// If the RPC was cancelled before we registered the cancellation callback,
// don't compile the TPU program.
OP_REQUIRES(ctx, !already_cancelled,
errors::Cancelled("RPC cancelled, not compiling TPU program"));
// We only want to abort the process if a cancellation actually occurs during
// compilation; we must deregister the callback in the success case. It
// doesn't hurt to also deregister the callback in the failure case; the
// CancellationManager ensures that already-registered callbacks will be run
// once cancellation has started.
auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
ctx->cancellation_manager()->DeregisterCallback(token);
done->store(true);
});
OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
}
} // namespace tpu } // namespace tpu
} // namespace tensorflow } // namespace tensorflow

View File

@ -53,7 +53,8 @@ class TpuCompileOpKernelCommon {
virtual ~TpuCompileOpKernelCommon() = default; virtual ~TpuCompileOpKernelCommon() = default;
virtual void Compute(OpKernelContext* ctx) = 0; void Compute(OpKernelContext* ctx);
virtual Status ComputeInternal(OpKernelContext* ctx) = 0;
// Computes shapes for each argument. Uses both the static shape from the // Computes shapes for each argument. Uses both the static shape from the
// metadata, and the dynamic shapes where the static shape is not // metadata, and the dynamic shapes where the static shape is not

View File

@ -95,6 +95,5 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
} }
return Status::OK(); return Status::OK();
} }
} // namespace tpu } // namespace tpu
} // namespace tensorflow } // namespace tensorflow

View File

@ -68,7 +68,6 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
// A callback called on exit. // A callback called on exit.
void LogAndExit(int code); void LogAndExit(int code);
} // namespace tpu } // namespace tpu
} // namespace tensorflow } // namespace tensorflow

View File

@ -31,6 +31,12 @@ void TpuCompile_ToTpuShapeRepresentation(
bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape, bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
SE_Status* status); SE_Status* status);
// XLA compilation cannot be cancelled. To avoid hanging the TF worker will exit
// when cancellation is requested for an XLA compile op. Some tests require this
// behavior to be disabled, and we test for this condition with the following
// flag function.
bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();
} // extern "C" } // extern "C"
struct TfTpu_UtilApiFn { struct TfTpu_UtilApiFn {

View File

@ -26,6 +26,7 @@ import (
"encoding/binary" "encoding/binary"
"fmt" "fmt"
"io" "io"
"math/bits"
"reflect" "reflect"
"runtime" "runtime"
"unsafe" "unsafe"
@ -80,7 +81,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
if dataType == String { if dataType == String {
// TF_STRING tensors are encoded as an array of 8-byte offsets // TF_STRING tensors are encoded as an array of 8-byte offsets
// followed by string data. See c_api.h. // followed by string data. See c_api.h.
nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value) nbytes = uintptr(nflattened*8 + int64(byteSizeOfEncodedStrings(val)))
} }
var shapePtr *C.int64_t var shapePtr *C.int64_t
if len(shape) > 0 { if len(shape) > 0 {
@ -94,9 +95,22 @@ func NewTensor(value interface{}) (*Tensor, error) {
raw := tensorData(t.c) raw := tensorData(t.c)
buf := bytes.NewBuffer(raw[:0:len(raw)]) buf := bytes.NewBuffer(raw[:0:len(raw)])
if dataType != String { if dataType != String {
if err := encodeTensor(buf, val, shape); err != nil { if isAllArray(val.Type()) {
// We have arrays all the way down, or just primitive types. We can
// just copy the memory in as it is all contiguous.
if err := copyPtr(buf, unpackEFace(value).data, int(val.Type().Size())); err != nil {
return nil, err return nil, err
} }
} else {
// When there are slices involved the memory for each leaf slice may
// not be contiguous with the others or in the order we might
// expect, so we need to work our way down to each slice of
// primitives and copy them individually
if err := encodeTensorWithSlices(buf, val, shape); err != nil {
return nil, err
}
}
if uintptr(buf.Len()) != nbytes { if uintptr(buf.Len()) != nbytes {
return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len()) return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
} }
@ -112,6 +126,43 @@ func NewTensor(value interface{}) (*Tensor, error) {
return t, nil return t, nil
} }
// isAllArray returns true if type is a primitive type or an array of primitive
// types or an array of ... etc. When this is true the data we want is
// contiguous in RAM.
func isAllArray(typ reflect.Type) bool {
switch typ.Kind() {
case reflect.Slice:
return false
case reflect.Array:
return isAllArray(typ.Elem())
default:
// We know the type is slices/arrays of slices/arrays of primitive types.
return true
}
}
// eface defines what an interface type actually is: a pointer to type
// information about the encapsulated type and a pointer to the encapsulated
// value.
type eface struct {
rtype unsafe.Pointer
data unsafe.Pointer
}
// unpackEFace gives us an efficient way to get a pointer to the value carried
// in an interface. If you wrap a pointer type in an interface then the pointer
// is directly stored in the interface struct. If you wrap a value type in an
// interface then the compiler copies the value into a newly allocated piece of
// memory and stores a pointer to that memory in the interface. So we're
// guaranteed to get a pointer. Go reflection doesn't expose the pointer to
// value types straightforwardly as it doesn't want you to think you have a
// reference to the original value. But we just want a pointer to make it
// efficient to read the value, so cheating like this should be safe and
// reasonable.
func unpackEFace(obj interface{}) *eface {
return (*eface)(unsafe.Pointer(&obj))
}
// ReadTensor constructs a Tensor with the provided type and shape from the // ReadTensor constructs a Tensor with the provided type and shape from the
// serialized tensor contents in r. // serialized tensor contents in r.
// //
@ -168,21 +219,152 @@ func (t *Tensor) Shape() []int64 { return t.shape }
// Tensor(int64, 0): int64 // Tensor(int64, 0): int64
// Tensor(float64, 3): [][][]float64 // Tensor(float64, 3): [][][]float64
func (t *Tensor) Value() interface{} { func (t *Tensor) Value() interface{} {
typ := typeOf(t.DataType(), t.Shape())
val := reflect.New(typ)
raw := tensorData(t.c) raw := tensorData(t.c)
if t.DataType() != String { shape := t.Shape()
if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil { dt := t.DataType()
panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err)) return decodeTensor(raw, shape, dt).Interface()
} }
func decodeTensor(raw []byte, shape []int64, dt DataType) reflect.Value {
// Create a 1-dimensional slice of the base large enough for the data and
// copy the data in.
n := int(numElements(shape))
var (
slice reflect.Value
typ reflect.Type
)
if dt == String {
strs, err := decodeOneDimString(raw, n)
if err != nil {
panic(bug("unable to decode string with shape %v: %v", shape, err))
}
slice = reflect.ValueOf(strs)
typ = slice.Type()
} else { } else {
nflattened := numElements(t.Shape()) typ = typeForDataType(dt)
d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()} l := n * int(typ.Size())
if err := d.decode(val, t.Shape()); err != nil { typ = reflect.SliceOf(typ)
panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err)) slice = reflect.MakeSlice(typ, n, n)
baseBytes := *(*[]byte)(unsafe.Pointer(&sliceHeader{
Data: unsafe.Pointer(slice.Pointer()),
Len: l,
Cap: l,
}))
copy(baseBytes, raw)
} }
// Now that we have the data in place in the base slice, we can add the
// dimensions. We want to walk backwards through the shape. If the shape is
// length 1 or 0 then we're already done.
if len(shape) == 0 {
return slice.Index(0)
} }
return reflect.Indirect(val).Interface() if len(shape) == 1 {
return slice
}
// We have a special case if the tensor has no data. Our backing slice is
// empty, but we still want to create slices following the shape. In this
// case only the final part of the shape will be 0 and we want to recalculate
// n at this point ignoring that 0.
// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
// want 6 zero length slices to group as follows.
// {{} {}} {{} {}} {{} {}}
if n == 0 {
n = int(numElements(shape[:len(shape)-1]))
}
for i := len(shape) - 2; i >= 0; i-- {
underlyingSize := typ.Elem().Size()
typ = reflect.SliceOf(typ)
subsliceLen := int(shape[i+1])
if subsliceLen != 0 {
n = n / subsliceLen
}
// Just using reflection it is difficult to avoid unnecessary
// allocations while setting up the sub-slices as the Slice function on
// a slice Value allocates. So we end up doing pointer arithmetic!
// Pointer() on a slice gives us access to the data backing the slice.
// We insert slice headers directly into this data.
data := unsafe.Pointer(slice.Pointer())
nextSlice := reflect.MakeSlice(typ, n, n)
for j := 0; j < n; j++ {
// This is equivalent to nSlice[j] = slice[j*subsliceLen: (j+1)*subsliceLen]
setSliceInSlice(nextSlice, j, sliceHeader{
Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
Len: subsliceLen,
Cap: subsliceLen,
})
}
slice = nextSlice
}
return slice
}
// setSliceInSlice sets slice[index] = content.
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
const sliceSize = unsafe.Sizeof(sliceHeader{})
// We must cast slice.Pointer to uintptr & back again to avoid GC issues.
// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
}
// decodeOneDimString decodes a string tensor into a one-dimensional []string.
func decodeOneDimString(raw []byte, nStrings int) ([]string, error) {
// Start by making an array of all the strings
strs := make([]string, nStrings)
// The first nStrings * 8 bytes of raw are offsets into the second half of
// the raw data. This second half is where the strings are encoded.
offsets := (*(*[]int64)(unsafe.Pointer(&raw)))[:nStrings]
// Reset raw after the offsets. Now the offsets will work relative to raw
raw = raw[nStrings*8:]
// Next we work out the final length of the string data so we can copy the
// good data out of raw (which is owned by the C tensor and won't be safe
// to access if the tensor is freed)
r := bytes.NewReader(raw)
var totalLength int
for _, offset := range offsets {
// At each offset we should find a varint length of a string.
// Errors here should mean the tensor is corrupt.
if _, err := r.Seek(offset, io.SeekStart); err != nil {
return nil, err
}
l, err := binary.ReadUvarint(r)
if err != nil {
return nil, err
}
totalLength += int(l)
}
// Let's allocate a big buffer to carry our string data.
stringData := make([]byte, 0, totalLength)
// Now copy the string data across into our new buffer, keeping track of the
// location of each string in the strs slice.
var cursor int
for i, offset := range offsets {
// At each offset we should find a varint length. Read it
if _, err := r.Seek(offset, io.SeekStart); err != nil {
return nil, err
}
l, err := binary.ReadUvarint(r)
if err != nil {
return nil, err
}
// Then copy the actual string into our large buffer
target := stringData[cursor : cursor+int(l)]
if _, err := r.Read(target); err != nil {
return nil, err
}
// Track where this string data is.
strs[i] = *(*string)(unsafe.Pointer(&target))
cursor += int(l)
}
// So now we have a big slice of strings
return strs, nil
} }
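The layout decodeOneDimString walks can be mocked up outside the binding: an 8-byte offset table followed by varint-length-prefixed string data. A toy Python sketch (little-endian offsets assumed for the illustration; the Go code uses the machine's native byte order):

```python
import struct

def put_uvarint(n: int) -> bytes:
    # Minimal unsigned-varint encoder for the length prefix.
    out = bytearray()
    while True:
        low, n = n & 0x7F, n >> 7
        out.append(low | (0x80 if n else 0))
        if not n:
            return bytes(out)

strings = [b"ab", b"", b"xyz"]
data, offsets = bytearray(), []
for s in strings:
    offsets.append(len(data))          # offset relative to the data region
    data += put_uvarint(len(s)) + s
raw = struct.pack("<%dq" % len(offsets), *offsets) + bytes(data)

# Decode: read each offset, then the varint length, then the payload bytes.
decoded = []
for i in range(len(strings)):
    pos = struct.unpack_from("<q", raw, 8 * i)[0] + 8 * len(strings)
    length, shift = 0, 0
    while True:
        b = raw[pos]
        pos += 1
        length |= (b & 0x7F) << shift
        shift += 7
        if not (b & 0x80):
            break
    decoded.append(raw[pos:pos + length])

print(decoded)  # [b'ab', b'', b'xyz']
```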
// WriteContentsTo writes the serialized contents of t to w. // WriteContentsTo writes the serialized contents of t to w.
@ -261,18 +443,18 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
return shape, dt, fmt.Errorf("unsupported type %v", typ) return shape, dt, fmt.Errorf("unsupported type %v", typ)
} }
// typeOf converts from a DataType and Shape to the equivalent Go type. func typeForDataType(dt DataType) reflect.Type {
func typeOf(dt DataType, shape []int64) reflect.Type {
var ret reflect.Type
for _, t := range types { for _, t := range types {
if dt == DataType(t.dataType) { if dt == DataType(t.dataType) {
ret = t.typ return t.typ
break
} }
} }
if ret == nil {
panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt)) panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
} }
// typeOf converts from a DataType and Shape to the equivalent Go type.
func typeOf(dt DataType, shape []int64) reflect.Type {
ret := typeForDataType(dt)
for range shape { for range shape {
ret = reflect.SliceOf(ret) ret = reflect.SliceOf(ret)
} }
@ -289,37 +471,37 @@ func numElements(shape []int64) int64 {
// byteSizeOfEncodedStrings returns the size of the encoded strings in val. // byteSizeOfEncodedStrings returns the size of the encoded strings in val.
// val MUST be a string, or a container (array/slice etc.) of strings. // val MUST be a string, or a container (array/slice etc.) of strings.
func byteSizeOfEncodedStrings(val interface{}) uintptr { // Tensorflow encodes strings as the varint encoded length followed by the
if s, ok := val.(string); ok { // TensorFlow encodes strings as the varint-encoded length followed by the
return uintptr(C.TF_StringEncodedSize(C.size_t(len(s)))) // overhead. So we just do that calculation in Go
func byteSizeOfEncodedStrings(val reflect.Value) int {
if val.Kind() == reflect.String {
return sizeVarUint(uint64(val.Len())) + val.Len()
}
if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
panic(fmt.Sprintf("unexpected type %s", val.Type()))
} }
// Otherwise must be an array or slice. // Otherwise must be an array or slice.
var size uintptr var size int
v := reflect.ValueOf(val) for i := 0; i < val.Len(); i++ {
for i := 0; i < v.Len(); i++ { size += byteSizeOfEncodedStrings(val.Index(i))
size += byteSizeOfEncodedStrings(v.Index(i).Interface())
} }
return size return size
} }
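The helper above sizes each string as an unsigned-varint length prefix plus the payload bytes; the same per-string arithmetic in Python (a hypothetical mirror, not part of the binding):

```python
def encoded_string_size(s: bytes) -> int:
    # Varint length prefix plus payload, mirroring the Go-side calculation.
    n, prefix = len(s), 1
    while n >= 0x80:
        n >>= 7
        prefix += 1
    return prefix + len(s)

print(encoded_string_size(b"cheesit"))  # 1-byte prefix + 7 bytes = 8
```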
// encodeTensor writes v to the specified buffer using the format specified in // sizeVarUint determines how many bytes it would take to encode the int v as
// c_api.h. Use stringEncoder for String tensors. // an unsigned varint
func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error { func sizeVarUint(v uint64) int {
switch v.Kind() { if v < 0x80 {
case reflect.Bool: return 1
b := byte(0)
if v.Bool() {
b = 1
} }
if err := w.WriteByte(b); err != nil { bits := bits.Len64(v)
return err return (bits + 6) / 7
}
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
return err
} }
case reflect.Array, reflect.Slice: // encodeTensorWithSlices writes v to the specified buffer using the format specified in
// c_api.h. Use stringEncoder for String tensors.
func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) error {
// If current dimension is a slice, verify that it has the expected size // If current dimension is a slice, verify that it has the expected size
// Go's type system makes that guarantee for arrays. // Go's type system makes that guarantee for arrays.
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
@ -327,71 +509,55 @@ func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
if v.Len() != expected { if v.Len() != expected {
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected) return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
} }
} else if v.Kind() != reflect.Array {
return fmt.Errorf("unsupported type %v", v.Type())
} }
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice // Once we have just a single dimension we can just copy the data
if len(shape) == 1 && v.Len() > 0 { if len(shape) == 1 && v.Len() > 0 {
switch v.Index(0).Kind() { elt := v.Index(0)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128: if !elt.CanAddr() {
return binary.Write(w, nativeEndian, v.Interface()) panic("cannot take address")
} }
ptr := unsafe.Pointer(elt.Addr().Pointer())
return copyPtr(w, ptr, v.Len()*int(elt.Type().Size()))
} }
subShape := shape[1:] subShape := shape[1:]
for i := 0; i < v.Len(); i++ { for i := 0; i < v.Len(); i++ {
err := encodeTensor(w, v.Index(i), subShape) err := encodeTensorWithSlices(w, v.Index(i), subShape)
if err != nil { if err != nil {
return err return err
} }
} }
default:
return fmt.Errorf("unsupported type %v", v.Type())
}
return nil return nil
} }
// decodeTensor decodes the Tensor from the buffer to ptr using the format // It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
// specified in c_api.h. Use stringDecoder for String tensors. // this is not inspected by the garbage collector
func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error { type sliceHeader struct {
switch typ.Kind() { Data unsafe.Pointer
case reflect.Bool: Len int
b, err := r.ReadByte() Cap int
if err != nil { }
// copyPtr copies the backing data for a slice or array directly into w. Note
// we don't need to worry about byte ordering because we want the natural byte
// order for the machine we're running on.
func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) error {
// Convert our slice header into a []byte so we can call w.Write
b := *(*[]byte)(unsafe.Pointer(&sliceHeader{
Data: ptr,
Len: l,
Cap: l,
}))
_, err := w.Write(b)
return err return err
} }
ptr.Elem().SetBool(b == 1)
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
return err
}
case reflect.Slice:
val := reflect.Indirect(ptr)
val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
if len(shape) == 1 && val.Len() > 0 {
switch val.Index(0).Kind() {
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
return binary.Read(r, nativeEndian, val.Interface())
}
}
for i := 0; i < val.Len(); i++ {
if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
return err
}
}
default:
return fmt.Errorf("unsupported type %v", typ)
}
return nil
}
type stringEncoder struct { type stringEncoder struct {
offsets io.Writer offsets *bytes.Buffer
data []byte data []byte
offset uint64 offset uint64
status *status status *status
@ -399,19 +565,18 @@ type stringEncoder struct {
func (e *stringEncoder) encode(v reflect.Value, shape []int64) error { func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
if v.Kind() == reflect.String { if v.Kind() == reflect.String {
if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil { if err := copyPtr(e.offsets, unsafe.Pointer(&e.offset), int(unsafe.Sizeof(e.offset))); err != nil {
return err return err
} }
var ( // A string is encoded as the varint length followed by the string bytes.
s = v.Interface().(string) // We do this in Go to avoid the considerable overhead of a cgo call into
src = C.CString(s) // the tensorflow library
srcLen = C.size_t(len(s)) s := v.String()
dst = (*C.char)(unsafe.Pointer(&e.data[e.offset])) n := binary.PutUvarint(e.data[e.offset:], uint64(len(s)))
dstLen = C.size_t(uint64(len(e.data)) - e.offset) e.offset += uint64(n)
) n = copy(e.data[e.offset:], s)
e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c)) e.offset += uint64(n)
C.free(unsafe.Pointer(src)) return nil
return e.status.Err()
} }
if v.Kind() == reflect.Slice { if v.Kind() == reflect.Slice {
@ -430,45 +595,6 @@ func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
return nil return nil
} }
type stringDecoder struct {
offsets io.Reader
data []byte
status *status
}
func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error {
if len(shape) == 0 {
var offset uint64
if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil {
return err
}
var (
src = (*C.char)(unsafe.Pointer(&d.data[offset]))
srcLen = C.size_t(len(d.data)) - C.size_t(offset)
dst *C.char
dstLen C.size_t
)
if offset > uint64(len(d.data)) {
return fmt.Errorf("invalid offsets in String Tensor")
}
C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c)
if err := d.status.Err(); err != nil {
return err
}
s := ptr.Interface().(*string)
*s = C.GoStringN(dst, C.int(dstLen))
return nil
}
val := reflect.Indirect(ptr)
val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0])))
for i := 0; i < val.Len(); i++ {
if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil {
return err
}
}
return nil
}
func bug(format string, args ...interface{}) error { func bug(format string, args ...interface{}) error {
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...)) return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
} }
@ -489,22 +615,3 @@ func isTensorSerializable(dataType DataType) error {
return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType) return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
} }
} }
// nativeEndian is the byte order for the local platform. Used to send back and
// forth Tensors with the C API. We test for endianness at runtime because
// some architectures can be booted into different endian modes.
var nativeEndian binary.ByteOrder
func init() {
buf := [2]byte{}
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
switch buf {
case [2]byte{0xCD, 0xAB}:
nativeEndian = binary.LittleEndian
case [2]byte{0xAB, 0xCD}:
nativeEndian = binary.BigEndian
default:
panic("Could not determine native endianness.")
}
}

View File

@ -18,6 +18,7 @@ package tensorflow
import ( import (
"bytes" "bytes"
"fmt"
"io" "io"
"reflect" "reflect"
"testing" "testing"
@ -276,6 +277,7 @@ func TestReadTensorReadAll(t *testing.T) {
} }
func benchmarkNewTensor(b *testing.B, v interface{}) { func benchmarkNewTensor(b *testing.B, v interface{}) {
b.ReportAllocs()
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
if t, err := NewTensor(v); err != nil || t == nil { if t, err := NewTensor(v); err != nil || t == nil {
b.Fatalf("(%v, %v)", t, err) b.Fatalf("(%v, %v)", t, err)
@ -283,32 +285,68 @@ func benchmarkNewTensor(b *testing.B, v interface{}) {
} }
} }
func BenchmarkNewTensor(b *testing.B) { func benchmarkValueTensor(b *testing.B, v interface{}) {
var ( t, err := NewTensor(v)
// Some sample sizes from the Inception image labeling model. if err != nil {
// Where input tensors correspond to a 224x224 RGB image b.Fatalf("(%v, %v)", t, err)
// flattened into a vector.
vector [224 * 224 * 3]int32
)
b.Run("[150528]", func(b *testing.B) { benchmarkNewTensor(b, vector) })
} }
b.ReportAllocs()
b.ResetTimer()
func benchmarkDecodeTensor(b *testing.B, t *Tensor) {
for i := 0; i < b.N; i++ { for i := 0; i < b.N; i++ {
_ = t.Value() _ = t.Value()
} }
} }
func BenchmarkDecodeTensor(b *testing.B) { func BenchmarkTensor(b *testing.B) {
var (
// Some sample sizes from the Inception image labeling model. // Some sample sizes from the Inception image labeling model.
// Where input tensors correspond to a 224x224 RGB image // Where input tensors correspond to a 224x224 RGB image
// flattened into a vector. // flattened into a vector.
vector [224 * 224 * 3]int32 var vector [224 * 224 * 3]int32
) var arrays [100][100][100]int32
t, err := NewTensor(vector)
if err != nil { l3 := make([][][]float32, 100)
b.Fatalf("(%v, %v)", t, err) l2 := make([][]float32, 100*100)
l1 := make([]float32, 100*100*100)
for i := range l2 {
l2[i] = l1[i*100 : (i+1)*100]
} }
b.Run("[150528]", func(b *testing.B) { benchmarkDecodeTensor(b, t) }) for i := range l3 {
l3[i] = l2[i*100 : (i+1)*100]
}
s1 := make([]string, 100*100*100)
s2 := make([][]string, 100*100)
s3 := make([][][]string, 100)
for i := range s1 {
s1[i] = "cheesit"
}
for i := range s2 {
s2[i] = s1[i*100 : (i+1)*100]
}
for i := range s3 {
s3[i] = s2[i*100 : (i+1)*100]
}
tests := []interface{}{
vector,
arrays,
l1,
l2,
l3,
s1,
s2,
s3,
}
b.Run("New", func(b *testing.B) {
for _, test := range tests {
b.Run(fmt.Sprintf("%T", test), func(b *testing.B) { benchmarkNewTensor(b, test) })
}
})
b.Run("Value", func(b *testing.B) {
for _, test := range tests {
b.Run(fmt.Sprintf("%T", test), func(b *testing.B) { benchmarkValueTensor(b, test) })
}
})
} }

View File

@ -440,76 +440,6 @@ typedef struct TfLiteTensor {
// `dims_signature` contains [1, -1, -1, 3]). // `dims_signature` contains [1, -1, -1, 3]).
const TfLiteIntArray* dims_signature; const TfLiteIntArray* dims_signature;
} TfLiteTensor; } TfLiteTensor;
#else
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// NOTE: This flag is opt-in only at compile time.
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
#endif // TF_LITE_STATIC_MEMORY
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// A structure representing an instance of a node. // A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not // This structure only exhibits the inputs, outputs and user defined data, not
@ -547,6 +477,112 @@ typedef struct TfLiteNode {
// WARNING: This is an experimental interface that is subject to change. // WARNING: This is an experimental interface that is subject to change.
struct TfLiteDelegate* delegate; struct TfLiteDelegate* delegate;
} TfLiteNode; } TfLiteNode;
#else
// NOTE: This flag is opt-in only at compile time.
//
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// This struct does not use:
// - allocation
// - buffer_handle
// - data_is_stale
// - delegate
// - dims_signature
// - name
// - sparsity
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
// Specific reduced TfLiteNode struct for TF Micro runtime. This struct contains
// only the minimum fields required to represent a node.
//
// This struct does not use:
// - delegate
// - intermediates
// - temporaries
typedef struct TfLiteNode {
// Inputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* inputs;
// Outputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* outputs;
// Opaque data provided by the node implementer through `Registration.init`.
void* user_data;
// Opaque data provided to the node if the node is a builtin. This is usually
// a structure defined in builtin_op_data.h
void* builtin_data;
// Custom initial data. This is the opaque data provided in the flatbuffer.
// WARNING: This is an experimental interface that is subject to change.
const void* custom_initial_data;
int custom_initial_data_size;
} TfLiteNode;
#endif // TF_LITE_STATIC_MEMORY
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// WARNING: This is an experimental interface that is subject to change. // WARNING: This is an experimental interface that is subject to change.
// //

View File

@ -263,6 +263,43 @@ TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator,
return kTfLiteOk; return kTfLiteOk;
} }
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseCeil(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseConcatenation(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLiteConcatenationParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLiteConcatenationParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const ConcatenationOptions* schema_params =
op->builtin_options_as_ConcatenationOptions();
if (schema_params != nullptr) {
params->activation =
ConvertActivation(schema_params->fused_activation_function());
params->axis = schema_params->axis();
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator, TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) { BuiltinDataAllocator* allocator, void** builtin_data) {
@ -295,6 +332,14 @@ TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
return kTfLiteOk; return kTfLiteOk;
} }
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseCos(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator, TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, BuiltinDataAllocator* allocator,
@ -339,6 +384,22 @@ TfLiteStatus ParseDequantize(const Operator*, BuiltinOperator, ErrorReporter*,
return kTfLiteOk; return kTfLiteOk;
} }
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseEqual(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseFloor(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator, TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, BuiltinDataAllocator* allocator,
@ -385,6 +446,53 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
return kTfLiteOk; return kTfLiteOk;
} }
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseGreater(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
// We have this parse function instead of directly returning kTfLiteOk from the
// switch-case in ParseOpData because this function is used as part of the
// selective registration for the OpResolver implementation in micro.
TfLiteStatus ParseGreaterEqual(const Operator*, BuiltinOperator, ErrorReporter*,
BuiltinDataAllocator*, void**) {
return kTfLiteOk;
}
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data) {
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
SafeBuiltinDataAllocator safe_allocator(allocator);
std::unique_ptr<TfLitePoolParams,
SafeBuiltinDataAllocator::BuiltinDataDeleter>
params = safe_allocator.Allocate<TfLitePoolParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
const Pool2DOptions* schema_params = op->builtin_options_as_Pool2DOptions();
if (schema_params != nullptr) {
params->padding = ConvertPadding(schema_params->padding());
params->stride_width = schema_params->stride_w();
params->stride_height = schema_params->stride_h();
params->filter_width = schema_params->filter_width();
params->filter_height = schema_params->filter_height();
params->activation =
ConvertActivation(schema_params->fused_activation_function());
} else {
// TODO(b/157480169): We should either return kTfLiteError or fill in some
// reasonable defaults in the params struct. We are not doing so until we
// better understand the ramifications of changing the legacy behavior.
}
*builtin_data = params.release();
return kTfLiteOk;
}
TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator, TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, BuiltinDataAllocator* allocator,
@ -532,6 +640,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
return ParseArgMin(op, op_type, error_reporter, allocator, builtin_data); return ParseArgMin(op, op_type, error_reporter, allocator, builtin_data);
} }
case BuiltinOperator_AVERAGE_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_CEIL: {
return ParseCeil(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_CONCATENATION: {
return ParseConcatenation(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_CONV_2D: { case BuiltinOperator_CONV_2D: {
return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data); return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data);
} }
@ -546,11 +667,32 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
builtin_data); builtin_data);
} }
case BuiltinOperator_FLOOR: {
return ParseFloor(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_FULLY_CONNECTED: { case BuiltinOperator_FULLY_CONNECTED: {
return ParseFullyConnected(op, op_type, error_reporter, allocator, return ParseFullyConnected(op, op_type, error_reporter, allocator,
builtin_data); builtin_data);
} }
case BuiltinOperator_GREATER: {
return ParseGreater(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_GREATER_EQUAL: {
return ParseGreaterEqual(op, op_type, error_reporter, allocator,
builtin_data);
}
case BuiltinOperator_MAX_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_L2_POOL_2D: {
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
}
case BuiltinOperator_QUANTIZE: { case BuiltinOperator_QUANTIZE: {
return ParseQuantize(op, op_type, error_reporter, allocator, return ParseQuantize(op, op_type, error_reporter, allocator,
builtin_data); builtin_data);
@ -592,23 +734,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
*builtin_data = params.release(); *builtin_data = params.release();
return kTfLiteOk; return kTfLiteOk;
} }
case BuiltinOperator_AVERAGE_POOL_2D:
case BuiltinOperator_MAX_POOL_2D:
case BuiltinOperator_L2_POOL_2D: {
auto params = safe_allocator.Allocate<TfLitePoolParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
params->padding = ConvertPadding(pool_params->padding());
params->stride_width = pool_params->stride_w();
params->stride_height = pool_params->stride_h();
params->filter_width = pool_params->filter_width();
params->filter_height = pool_params->filter_height();
params->activation =
ConvertActivation(pool_params->fused_activation_function());
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: { case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
auto params = safe_allocator.Allocate<TfLiteSequenceRNNParams>(); auto params = safe_allocator.Allocate<TfLiteSequenceRNNParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr); TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -666,18 +791,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_HASHTABLE_LOOKUP: case BuiltinOperator_HASHTABLE_LOOKUP:
// no-op. // no-op.
return kTfLiteOk; return kTfLiteOk;
case BuiltinOperator_CONCATENATION: {
auto params = safe_allocator.Allocate<TfLiteConcatenationParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr);
if (const auto* concatenation_params =
op->builtin_options_as_ConcatenationOptions()) {
params->activation = ConvertActivation(
concatenation_params->fused_activation_function());
params->axis = concatenation_params->axis();
}
*builtin_data = params.release();
return kTfLiteOk;
}
case BuiltinOperator_MUL: { case BuiltinOperator_MUL: {
auto params = safe_allocator.Allocate<TfLiteMulParams>(); auto params = safe_allocator.Allocate<TfLiteMulParams>();
TF_LITE_ENSURE(error_reporter, params != nullptr); TF_LITE_ENSURE(error_reporter, params != nullptr);
@ -1102,10 +1215,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
case BuiltinOperator_EQUAL: case BuiltinOperator_EQUAL:
case BuiltinOperator_EXP: case BuiltinOperator_EXP:
case BuiltinOperator_EXPAND_DIMS: case BuiltinOperator_EXPAND_DIMS:
case BuiltinOperator_CEIL:
case BuiltinOperator_FLOOR:
case BuiltinOperator_GREATER:
case BuiltinOperator_GREATER_EQUAL:
case BuiltinOperator_HARD_SWISH: case BuiltinOperator_HARD_SWISH:
case BuiltinOperator_LESS: case BuiltinOperator_LESS:
case BuiltinOperator_LESS_EQUAL: case BuiltinOperator_LESS_EQUAL:
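The point of these thin wrappers is that a micro resolver can reference one parse function per op and let the linker discard the rest, instead of always pulling in the whole ParseOpData switch. A rough sketch of that idea (the table and its name are hypothetical, not the actual micro OpResolver API):

// Hypothetical op -> parser table; the signatures match the declarations in
// flatbuffer_conversions.h below.
using ParseFn = TfLiteStatus (*)(const Operator*, BuiltinOperator,
                                 ErrorReporter*, BuiltinDataAllocator*, void**);

struct ParserEntry {
  BuiltinOperator op;
  ParseFn parse;
};

constexpr ParserEntry kParsers[] = {
    {BuiltinOperator_CEIL, ParseCeil},
    {BuiltinOperator_CONCATENATION, ParseConcatenation},
    {BuiltinOperator_CONV_2D, ParseConv2D},
};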

View File

@ -91,10 +91,23 @@ TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data); BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseCeil(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseConcatenation(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type, TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data); BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseCos(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator op_type, TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, BuiltinDataAllocator* allocator,
@ -105,11 +118,32 @@ TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type,
BuiltinDataAllocator* allocator, BuiltinDataAllocator* allocator,
void** builtin_data); void** builtin_data);
TfLiteStatus ParseEqual(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseFloor(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator op_type, TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, BuiltinDataAllocator* allocator,
void** builtin_data); void** builtin_data);
TfLiteStatus ParseGreater(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseGreaterEqual(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator,
void** builtin_data);
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, void** builtin_data);
TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type, TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type,
ErrorReporter* error_reporter, ErrorReporter* error_reporter,
BuiltinDataAllocator* allocator, BuiltinDataAllocator* allocator,

View File

@ -49,6 +49,7 @@ cc_library(
":tensor_type", ":tensor_type",
":util", ":util",
"//tensorflow/lite/delegates/gpu/common:access_type", "//tensorflow/lite/delegates/gpu/common:access_type",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:status",
"//tensorflow/lite/delegates/gpu/common:types", "//tensorflow/lite/delegates/gpu/common:types",
"//tensorflow/lite/delegates/gpu/common:util", "//tensorflow/lite/delegates/gpu/common:util",
@ -84,6 +85,7 @@ cc_library(
":gpu_object", ":gpu_object",
":opencl_wrapper", ":opencl_wrapper",
":util", ":util",
"//tensorflow/lite/delegates/gpu/common:data_type",
"//tensorflow/lite/delegates/gpu/common:status", "//tensorflow/lite/delegates/gpu/common:status",
"@com_google_absl//absl/strings", "@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span", "@com_google_absl//absl/types:span",
@ -330,6 +332,7 @@ cc_library(
cc_library( cc_library(
name = "gpu_object", name = "gpu_object",
srcs = ["gpu_object.cc"],
hdrs = ["gpu_object.h"], hdrs = ["gpu_object.h"],
deps = [ deps = [
":opencl_wrapper", ":opencl_wrapper",

View File

@ -21,6 +21,7 @@ limitations under the License.
#include "absl/strings/str_split.h" #include "absl/strings/str_split.h"
#include "absl/strings/substitute.h" #include "absl/strings/substitute.h"
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h" #include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite { namespace tflite {
@ -457,20 +458,14 @@ std::string Arguments::GetListOfArgs() {
for (auto& t : buffers_) { for (auto& t : buffers_) {
const std::string type_name = const std::string type_name =
t.second.data_type == DataType::FLOAT32 ? "float" : "half"; t.second.data_type == DataType::FLOAT32 ? "float" : "half";
std::string memory_type; std::string attributes;
switch (t.second.memory_type) { for (const auto& attr : t.second.attributes) {
case MemoryType::GLOBAL: attributes += absl::StrCat(" __attribute__((", attr, "))");
memory_type = "__global";
break;
case MemoryType::CONSTANT:
memory_type = "__constant";
break;
case MemoryType::LOCAL:
memory_type = "__local";
break;
} }
AppendArgument(absl::StrCat(memory_type, " ", type_name, AppendArgument(
t.second.element_size, "* ", t.first), absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
ToCLDataType(t.second.data_type, t.second.element_size),
"* ", t.first, attributes),
&result); &result);
} }
for (auto& t : image_buffers_) { for (auto& t : image_buffers_) {
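The net effect of this hunk is that buffer kernel arguments are declared from the descriptor's memory type, CL data type, and optional attributes instead of a hand-rolled switch. A standalone sketch of the string being assembled (std::string only; the attribute value is made up):

#include <string>
#include <vector>

std::string BuildBufferArgDecl(const std::string& memory_type,   // "__constant"
                               const std::string& cl_type,       // "half4"
                               const std::string& name,
                               const std::vector<std::string>& attributes) {
  std::string attr_str;
  for (const std::string& attr : attributes) {
    attr_str += " __attribute__((" + attr + "))";
  }
  return memory_type + " " + cl_type + "* " + name + attr_str;
}
// BuildBufferArgDecl("__constant", "half4", "weights", {"max_constant_size(65536)"})
//   -> "__constant half4* weights __attribute__((max_constant_size(65536)))"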

View File

@ -15,6 +15,9 @@ limitations under the License.
#include "tensorflow/lite/delegates/gpu/cl/buffer.h" #include "tensorflow/lite/delegates/gpu/cl/buffer.h"
#include <string>
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
#include "tensorflow/lite/delegates/gpu/common/status.h" #include "tensorflow/lite/delegates/gpu/common/status.h"
namespace tflite { namespace tflite {
@ -51,6 +54,7 @@ GPUResources BufferDescriptor::GetGPUResources(AccessType access_type) const {
desc.access_type = access_type; desc.access_type = access_type;
desc.element_size = element_size; desc.element_size = element_size;
desc.memory_type = memory_type; desc.memory_type = memory_type;
desc.attributes = attributes;
resources.buffers.push_back({"buffer", desc}); resources.buffers.push_back({"buffer", desc});
return resources; return resources;
} }
@ -61,7 +65,7 @@ absl::Status BufferDescriptor::PerformSelector(
if (selector == "Read") { if (selector == "Read") {
return PerformReadSelector(args, result); return PerformReadSelector(args, result);
} else if (selector == "GetPtr") { } else if (selector == "GetPtr") {
return PerformGetPtrSelector(args, result); return PerformGetPtrSelector(args, template_args, result);
} else { } else {
return absl::NotFoundError(absl::StrCat( return absl::NotFoundError(absl::StrCat(
"BufferDescriptor don't have selector with name - ", selector)); "BufferDescriptor don't have selector with name - ", selector));
@ -80,13 +84,34 @@ absl::Status BufferDescriptor::PerformReadSelector(
} }
absl::Status BufferDescriptor::PerformGetPtrSelector( absl::Status BufferDescriptor::PerformGetPtrSelector(
const std::vector<std::string>& args, std::string* result) const { const std::vector<std::string>& args,
if (!args.empty()) { const std::vector<std::string>& template_args, std::string* result) const {
return absl::NotFoundError( if (args.size() > 1) {
absl::StrCat("BufferDescriptor GetPtr require zero arguments, but ", return absl::NotFoundError(absl::StrCat(
"BufferDescriptor GetPtr require one or zero arguments, but ",
args.size(), " was passed")); args.size(), " was passed"));
} }
*result = "buffer"; if (template_args.size() > 1) {
return absl::NotFoundError(
absl::StrCat("BufferDescriptor GetPtr require one or zero teemplate "
"arguments, but ",
template_args.size(), " was passed"));
}
std::string conversion;
if (template_args.size() == 1) {
const std::string type_name = ToCLDataType(element_type, element_size);
if (type_name != template_args[0]) {
conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
template_args[0], "*)&");
}
}
if (args.empty()) {
*result = absl::StrCat(conversion, "buffer");
} else if (conversion.empty()) {
*result = absl::StrCat("(buffer + ", args[0], ")");
} else {
*result = absl::StrCat(conversion, "buffer[", args[0], "]");
}
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -30,9 +30,10 @@ namespace gpu {
namespace cl { namespace cl {
struct BufferDescriptor : public GPUObjectDescriptor { struct BufferDescriptor : public GPUObjectDescriptor {
DataType element_type; // FLOAT32 or FLOAT16 DataType element_type;
int element_size; int element_size;
MemoryType memory_type = MemoryType::GLOBAL; MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
absl::Status PerformSelector(const std::string& selector, absl::Status PerformSelector(const std::string& selector,
const std::vector<std::string>& args, const std::vector<std::string>& args,
@ -42,8 +43,9 @@ struct BufferDescriptor : public GPUObjectDescriptor {
GPUResources GetGPUResources(AccessType access_type) const override; GPUResources GetGPUResources(AccessType access_type) const override;
absl::Status PerformReadSelector(const std::vector<std::string>& args, absl::Status PerformReadSelector(const std::vector<std::string>& args,
std::string* result) const; std::string* result) const;
absl::Status PerformGetPtrSelector(const std::vector<std::string>& args, absl::Status PerformGetPtrSelector(
std::string* result) const; const std::vector<std::string>& args,
const std::vector<std::string>& template_args, std::string* result) const;
}; };
// Buffer represent linear GPU data storage with arbitrary data format. // Buffer represent linear GPU data storage with arbitrary data format.

View File

@ -0,0 +1,37 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
namespace tflite {
namespace gpu {
namespace cl {
std::string MemoryTypeToCLType(MemoryType type) {
switch (type) {
case MemoryType::GLOBAL:
return "__global";
case MemoryType::CONSTANT:
return "__constant";
case MemoryType::LOCAL:
return "__local";
}
return "";
}
} // namespace cl
} // namespace gpu
} // namespace tflite

View File

@ -56,11 +56,14 @@ struct GPUImageBufferDescriptor {
enum class MemoryType { GLOBAL, CONSTANT, LOCAL }; enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
std::string MemoryTypeToCLType(MemoryType type);
struct GPUBufferDescriptor { struct GPUBufferDescriptor {
DataType data_type; DataType data_type;
AccessType access_type; AccessType access_type;
int element_size; int element_size;
MemoryType memory_type = MemoryType::GLOBAL; MemoryType memory_type = MemoryType::GLOBAL;
std::vector<std::string> attributes;
cl_mem memory; cl_mem memory;
}; };

View File

@ -28,21 +28,17 @@ namespace gpu {
namespace cl { namespace cl {
namespace { namespace {
std::string GenerateConvolutionTransposedCode( std::string GenerateConvolutionTransposedCode(const OperationDef& op_def,
const OperationDef& op_def, int src_depth, int dst_channels, int src_depth, int dst_channels,
const int2& kernel_size, const CLDevice& device, const int2& kernel_size,
const std::vector<ElementwiseOperation*>& linked_operations) { Arguments* args) {
TensorCodeGenerator src_tensor( args->AddObjectRef(
"src_data", "src_tensor", AccessType::READ,
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"}, absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
op_def.src_tensors[0]); args->AddObjectRef(
TensorCodeGenerator dst_tensor( "dst_tensor", AccessType::WRITE,
"dst_data", absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
op_def.dst_tensors[0]);
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
std::string c = GetCommonDefines(op_def.precision);
const std::string channel_x = dst_channels == 1 ? "" : ".x"; const std::string channel_x = dst_channels == 1 ? "" : ".x";
const std::vector<std::string> postfix = {channel_x, ".y", ".z", ".w"}; const std::vector<std::string> postfix = {channel_x, ".y", ".z", ".w"};
const std::vector<std::string> channel = {".x", ".y", ".z", ".w"}; const std::vector<std::string> channel = {".x", ".y", ".z", ".w"};
@ -62,36 +58,33 @@ std::string GenerateConvolutionTransposedCode(
break; break;
} }
std::string c = GetCommonDefines(op_def.precision);
c += "__kernel void main_function(\n"; c += "__kernel void main_function(\n";
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n"; c += "$0) {\n";
c += " __constant FLT4* filters";
c += GetArgsDeclaration(linked_operations);
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
c += " int4 src_size, \n";
c += " int4 dst_size, \n";
c += " FLT4 bias_value \n";
c += ") {\n";
if (op_def.IsBatchSupported()) { if (op_def.IsBatchSupported()) {
c += " int linear_id = get_global_id(0);\n"; c += " int linear_id = get_global_id(0);\n";
c += " int X = linear_id / dst_size.w;\n"; c += " int X = linear_id / args.dst_tensor.Batch();\n";
c += " int B = linear_id % dst_size.w;\n"; c += " int B = linear_id % args.dst_tensor.Batch();\n";
c += " args.dst_tensor.SetBatchRef(B);\n";
c += " args.src_tensor.SetBatchRef(B);\n";
} else { } else {
c += " int X = get_global_id(0);\n"; c += " int X = get_global_id(0);\n";
} }
c += " int Y = get_global_id(1);\n"; c += " int Y = get_global_id(1);\n";
c += " if (X >= src_size.x || Y >= src_size.y) return;\n"; c += " if (X >= args.src_tensor.Width() || Y >= args.src_tensor.Height()) "
"return;\n";
c += " " + accum_type + " r[" + std::to_string(kernel_size.y) + "][" + c += " " + accum_type + " r[" + std::to_string(kernel_size.y) + "][" +
std::to_string(kernel_size.x) + "];\n"; std::to_string(kernel_size.x) + "];\n";
c += " {\n"; c += " {\n";
c += " FLT4 src = " + src_tensor.ReadWHSB("X", "Y", "0", batch_id) + ";\n"; c += " FLT4 src = args.src_tensor.Read(X, Y, 0);\n";
int index = 0; int index = 0;
for (int y = 0; y < kernel_size.y; ++y) { for (int y = 0; y < kernel_size.y; ++y) {
for (int x = 0; x < kernel_size.x; ++x) { for (int x = 0; x < kernel_size.x; ++x) {
std::string r_s = std::string r_s =
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]"; " r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
for (int d = 0; d < dst_channels; ++d) { for (int d = 0; d < dst_channels; ++d) {
c += r_s + postfix[d] + " = dot(src, filters[" + std::to_string(index) + c += r_s + postfix[d] + " = dot(src, args.weights.Read(" +
"]);\n"; std::to_string(index) + "));\n";
index++; index++;
} }
} }
@ -100,15 +93,15 @@ std::string GenerateConvolutionTransposedCode(
for (int i = 1; i < src_depth; ++i) { for (int i = 1; i < src_depth; ++i) {
c += " if (X > " + std::to_string(-i) + c += " if (X > " + std::to_string(-i) +
") { // always true, to reduce registers usage\n"; ") { // always true, to reduce registers usage\n";
c += " FLT4 src = " + c +=
src_tensor.ReadWHSB("X", "Y", std::to_string(i), batch_id) + ";\n"; " FLT4 src = args.src_tensor.Read(X, Y, " + std::to_string(i) + ");\n";
for (int y = 0; y < kernel_size.y; ++y) { for (int y = 0; y < kernel_size.y; ++y) {
for (int x = 0; x < kernel_size.x; ++x) { for (int x = 0; x < kernel_size.x; ++x) {
std::string r_s = std::string r_s =
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]"; " r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
for (int d = 0; d < dst_channels; ++d) { for (int d = 0; d < dst_channels; ++d) {
c += r_s + postfix[d] + " += dot(src, filters[" + c += r_s + postfix[d] + " += dot(src, args.weights.Read(" +
std::to_string(index) + "]);\n"; std::to_string(index) + "));\n";
index++; index++;
} }
} }
@ -121,21 +114,16 @@ std::string GenerateConvolutionTransposedCode(
for (int x = 0; x < kernel_size.x; ++x) { for (int x = 0; x < kernel_size.x; ++x) {
const std::string x_coord = "X + " + std::to_string(x); const std::string x_coord = "X + " + std::to_string(x);
const std::string y_coord = "Y + " + std::to_string(y); const std::string y_coord = "Y + " + std::to_string(y);
c += " if (" + x_coord + " < dst_size.x && " + y_coord + c += " if (" + x_coord + " < args.dst_tensor.Width() && " + y_coord +
" < dst_size.y) {\n"; " < args.dst_tensor.Height()) {\n";
c += " FLT4 result = bias_value;\n"; c += " FLT4 result = args.weights.Read(" + std::to_string(index) +
");\n";
for (int d = 0; d < dst_channels; ++d) { for (int d = 0; d < dst_channels; ++d) {
c += " result" + channel[d] + " += r[" + std::to_string(y) + "][" + c += " result" + channel[d] + " += r[" + std::to_string(y) + "][" +
std::to_string(x) + "]" + postfix[d] + ";\n"; std::to_string(x) + "]" + postfix[d] + ";\n";
} }
const std::string x_3dcoord = op_def.IsBatchSupported() c += " args.dst_tensor.Write(result, " + x_coord + ", " + y_coord +
? "(" + x_coord + ") * dst_size.w + B" ", 0);\n";
: x_coord;
const LinkingContext context{"result", x_3dcoord, y_coord, "0"};
c += PostProcess(linked_operations, context);
c += " " +
dst_tensor.WriteWHSB("result", x_coord, y_coord, "0", batch_id) +
"\n";
c += " }\n"; c += " }\n";
} }
} }
@ -150,19 +138,11 @@ ConvolutionTransposedThin::ConvolutionTransposedThin(
: GPUOperation(definition), : GPUOperation(definition),
kernel_size_(attr.weights.shape.w, attr.weights.shape.h), kernel_size_(attr.weights.shape.w, attr.weights.shape.h),
src_channels_(attr.weights.shape.i), src_channels_(attr.weights.shape.i),
dst_channels_(attr.weights.shape.o) { dst_channels_(attr.weights.shape.o) {}
float4 bias_value(0.0f);
for (int i = 0; i < attr.weights.shape.o; ++i) {
bias_value[i] = attr.bias.data[i];
}
bias_value_ = FLT4(definition_.precision, bias_value);
}
ConvolutionTransposedThin::ConvolutionTransposedThin( ConvolutionTransposedThin::ConvolutionTransposedThin(
ConvolutionTransposedThin&& operation) ConvolutionTransposedThin&& operation)
: GPUOperation(std::move(operation)), : GPUOperation(std::move(operation)),
weights_buf_(std::move(operation.weights_buf_)),
bias_value_(std::move(operation.bias_value_)),
kernel_size_(operation.kernel_size_), kernel_size_(operation.kernel_size_),
src_channels_(operation.src_channels_), src_channels_(operation.src_channels_),
dst_channels_(operation.dst_channels_), dst_channels_(operation.dst_channels_),
@ -172,8 +152,6 @@ ConvolutionTransposedThin::ConvolutionTransposedThin(
ConvolutionTransposedThin& ConvolutionTransposedThin::operator=( ConvolutionTransposedThin& ConvolutionTransposedThin::operator=(
ConvolutionTransposedThin&& operation) { ConvolutionTransposedThin&& operation) {
if (this != &operation) { if (this != &operation) {
weights_buf_ = std::move(operation.weights_buf_);
bias_value_ = std::move(operation.bias_value_);
std::swap(kernel_size_, operation.kernel_size_); std::swap(kernel_size_, operation.kernel_size_);
std::swap(src_channels_, operation.src_channels_); std::swap(src_channels_, operation.src_channels_);
std::swap(dst_channels_, operation.dst_channels_); std::swap(dst_channels_, operation.dst_channels_);
@ -186,9 +164,15 @@ ConvolutionTransposedThin& ConvolutionTransposedThin::operator=(
absl::Status ConvolutionTransposedThin::Compile( absl::Status ConvolutionTransposedThin::Compile(
const CreationContext& creation_context) { const CreationContext& creation_context) {
const auto code = GenerateConvolutionTransposedCode( std::string code = GenerateConvolutionTransposedCode(
definition_, DivideRoundUp(src_channels_, 4), dst_channels_, kernel_size_, definition_, DivideRoundUp(src_channels_, 4), dst_channels_, kernel_size_,
*creation_context.device, linked_operations_); &args_);
std::string element_wise_code;
RETURN_IF_ERROR(
MergeOperations(linked_operations_, &args_, &element_wise_code));
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
{{"dst_tensor", element_wise_code}},
&code));
std::vector<CompilerOptions> options; std::vector<CompilerOptions> options;
if (definition_.precision == CalculationsPrecision::F16 && if (definition_.precision == CalculationsPrecision::F16 &&
@ -202,15 +186,10 @@ absl::Status ConvolutionTransposedThin::Compile(
} }
absl::Status ConvolutionTransposedThin::BindArguments() { absl::Status ConvolutionTransposedThin::BindArguments() {
kernel_.ResetBindingCounter(); RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr())); RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr())); RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_)); return args_.Bind(kernel_.kernel());
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
RETURN_IF_ERROR(kernel_.SetBytesAuto(bias_value_));
return absl::OkStatus();
} }
int3 ConvolutionTransposedThin::GetGridSize() const { int3 ConvolutionTransposedThin::GetGridSize() const {
@ -248,7 +227,7 @@ absl::Status CreateConvolutionTransposedThin(
} }
*result = ConvolutionTransposedThin(definition, attr); *result = ConvolutionTransposedThin(definition, attr);
RETURN_IF_ERROR( RETURN_IF_ERROR(
result->UploadWeights(attr.weights, creation_context.context)); result->UploadData(attr.weights, attr.bias, creation_context.context));
return absl::OkStatus(); return absl::OkStatus();
} }

View File

@ -58,7 +58,8 @@ class ConvolutionTransposedThin : public GPUOperation {
ConvolutionTransposedThin(const OperationDef& definition, ConvolutionTransposedThin(const OperationDef& definition,
const ConvolutionTransposedAttributes& attr); const ConvolutionTransposedAttributes& attr);
template <DataType T> template <DataType T>
absl::Status UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights, absl::Status UploadData(const tflite::gpu::Tensor<OHWI, T>& weights,
const tflite::gpu::Tensor<Linear, T>& biases,
CLContext* context); CLContext* context);
template <DataType S, typename T> template <DataType S, typename T>
@ -68,9 +69,6 @@ class ConvolutionTransposedThin : public GPUOperation {
absl::Status BindArguments(); absl::Status BindArguments();
int3 GetGridSize() const; int3 GetGridSize() const;
Buffer weights_buf_;
FLT4 bias_value_;
int2 kernel_size_; int2 kernel_size_;
int src_channels_; int src_channels_;
int dst_channels_; int dst_channels_;
@ -80,25 +78,50 @@ class ConvolutionTransposedThin : public GPUOperation {
}; };
template <DataType T> template <DataType T>
absl::Status ConvolutionTransposedThin::UploadWeights( absl::Status ConvolutionTransposedThin::UploadData(
const tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) { const tflite::gpu::Tensor<OHWI, T>& weights,
const tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
const int src_depth = DivideRoundUp(src_channels_, 4); const int src_depth = DivideRoundUp(src_channels_, 4);
const int elements_count = const int flt4_count =
kernel_size_.x * kernel_size_.y * src_depth * 4 * dst_channels_; kernel_size_.x * kernel_size_.y * src_depth * dst_channels_;
const int float4_size = const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
definition_.precision == CalculationsPrecision::F32 ? 16 : 8;
if (definition_.GetDataType() == DataType::FLOAT32) { BufferDescriptor desc;
std::vector<float4> gpu_data(elements_count); desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
desc.element_size = 4;
desc.memory_type = MemoryType::CONSTANT;
Buffer weights_buffer;
if (f32_weights) {
std::vector<float4> gpu_data(flt4_count);
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data)); RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
return CreateReadOnlyBuffer(float4_size * elements_count, gpu_data.data(), float4 bias_value(0.0f);
context, &weights_buf_); for (int i = 0; i < weights.shape.o; ++i) {
} else { bias_value[i] = biases.data[i];
std::vector<half4> gpu_data(elements_count);
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
return CreateReadOnlyBuffer(float4_size * elements_count, gpu_data.data(),
context, &weights_buf_);
} }
gpu_data.push_back(bias_value);
RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(float4) * gpu_data.size(),
gpu_data.data(), context,
&weights_buffer));
} else {
std::vector<half4> gpu_data(flt4_count);
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
half4 bias_value(0.0f);
for (int i = 0; i < weights.shape.o; ++i) {
bias_value[i] = biases.data[i];
}
gpu_data.push_back(bias_value);
RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(half4) * gpu_data.size(),
gpu_data.data(), context,
&weights_buffer));
}
args_.AddObject("weights", AccessType::READ,
absl::make_unique<Buffer>(std::move(weights_buffer)),
absl::make_unique<BufferDescriptor>(desc));
return absl::OkStatus();
} }
template <DataType S, typename T> template <DataType S, typename T>
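UploadData packs the bias into the same constant buffer as the weights: flt4_count weight vectors followed by one extra FLT4 holding the biases, which is what the generated kernel reads via args.weights.Read(index) after the weight loop. A worked sizing example under assumed shapes (2x2 kernel, 3 input channels, 2 output channels):

// Assumed attribute shapes, only to illustrate the buffer layout:
const int src_depth  = (3 + 3) / 4;            // DivideRoundUp(3, 4) == 1
const int flt4_count = 2 * 2 * src_depth * 2;  // 8 weight FLT4 values
// gpu_data then holds 8 weights plus 1 trailing bias value (9 FLT4 entries),
// and the generated code reads the bias at args.weights.Read(8).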

View File

@ -147,6 +147,13 @@ absl::Status TensorDescriptor::PerformSelector(
} else if (selector == "Slices") { } else if (selector == "Slices") {
*result = "slices"; *result = "slices";
return absl::OkStatus(); return absl::OkStatus();
} else if (selector == "SliceStride") {
if (IsBatchedWidth()) {
*result = "width_batched * height";
} else {
*result = "width * height";
}
return absl::OkStatus();
} else if (selector == "Channels") { } else if (selector == "Channels") {
*result = "channels"; *result = "channels";
return absl::OkStatus(); return absl::OkStatus();

View File

@ -55,9 +55,12 @@ inline void GetActivationMinMax(FusedActivationFunctionType ac,
} }
} }
inline float ActivationFunctionWithMinMax(float x, float output_activation_min, template <typename T>
float output_activation_max) { inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
return std::min(std::max(x, output_activation_min), output_activation_max); T output_activation_max) {
using std::max;
using std::min;
return min(max(x, output_activation_min), output_activation_max);
} }
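Turning ActivationFunctionWithMinMax into a template (with unqualified min/max behind the using-declarations) lets one helper clamp float, int32, and int64 values alike. A quick sanity check, assuming this header is included:

#include <cstdint>

void ActivationClampSketch() {
  float f = ActivationFunctionWithMinMax(3.5f, 0.0f, 6.0f);       // -> 3.5f
  int64_t q = ActivationFunctionWithMinMax<int64_t>(-21, -1, 1);  // -> -1
  (void)f;
  (void)q;
}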
// Legacy function, left for compatibility only. // Legacy function, left for compatibility only.

View File

@ -2766,37 +2766,39 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
} }
} }
inline void SubWithActivation(const ArithmeticParams& params, inline void SetActivationMinMax(const ArithmeticParams& params,
const RuntimeShape& input1_shape, int32* activation_min, int32* activation_max) {
const int32* input1_data, *activation_min = params.quantized_activation_min;
const RuntimeShape& input2_shape, *activation_max = params.quantized_activation_max;
const int32* input2_data,
const RuntimeShape& output_shape,
int32* output_data) {
ruy::profiler::ScopeLabel label("SubWithActivation/int32");
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.quantized_activation_min,
params.quantized_activation_max);
}
} }
inline void SubWithActivation(const ArithmeticParams& params, inline void SetActivationMinMax(const ArithmeticParams& params,
const RuntimeShape& input1_shape, float* activation_min, float* activation_max) {
const float* input1_data, *activation_min = params.float_activation_min;
const RuntimeShape& input2_shape, *activation_max = params.float_activation_max;
const float* input2_data, }
const RuntimeShape& output_shape,
float* output_data) { inline void SetActivationMinMax(const ArithmeticParams& params,
ruy::profiler::ScopeLabel label("SubWithActivation/float"); int64_t* activation_min,
int64_t* activation_max) {
*activation_min = params.int64_activation_min;
*activation_max = params.int64_activation_max;
}
template <typename T>
inline void SubWithActivation(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
const int flat_size = const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape); MatchingElementsSize(input1_shape, input2_shape, output_shape);
T activation_min, activation_max;
SetActivationMinMax(params, &activation_min, &activation_max);
for (int i = 0; i < flat_size; ++i) { for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax( output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min, input1_data[i] - input2_data[i], activation_min, activation_max);
params.float_activation_max);
} }
} }
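The SetActivationMinMax overloads let the single templated SubWithActivation stay type-agnostic: the output pointer type alone decides which ArithmeticParams fields supply the clamp range. A sketch of the dispatch, assuming the surrounding tflite headers are included:

#include <cstdint>

void ActivationDispatchSketch(const ArithmeticParams& params) {
  int32 qmin, qmax;
  SetActivationMinMax(params, &qmin, &qmax);  // quantized_activation_{min,max}
  float fmin, fmax;
  SetActivationMinMax(params, &fmin, &fmax);  // float_activation_{min,max}
  int64_t imin, imax;
  SetActivationMinMax(params, &imin, &imax);  // int64_activation_{min,max}
}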

View File

@ -1495,6 +1495,7 @@ inline void GatherNd(const RuntimeShape& params_shape,
} }
} }
#ifndef TF_LITE_STATIC_MEMORY
template <typename IndicesT = int32> template <typename IndicesT = int32>
inline void GatherNdString(const RuntimeShape& params_shape, inline void GatherNdString(const RuntimeShape& params_shape,
const TfLiteTensor* params_data, const TfLiteTensor* params_data,
@ -1517,6 +1518,7 @@ inline void GatherNdString(const RuntimeShape& params_shape,
} }
buffer.WriteToTensor(output_data, /*new_shape=*/nullptr); buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
} }
#endif
template <typename IndicesT, typename UpdatesT> template <typename IndicesT, typename UpdatesT>
inline void ScatterNd(const RuntimeShape& indices_shape, inline void ScatterNd(const RuntimeShape& indices_shape,

View File

@ -260,6 +260,45 @@ inline void BroadcastSubSlow(const ArithmeticParams& params,
NDOpsHelper<N>(output_desc, sub_func); NDOpsHelper<N>(output_desc, sub_func);
} }
template <int N = 5>
void BroadcastSubSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int64_t* input1_data,
const RuntimeShape& input2_shape,
const int64_t* input2_data,
const RuntimeShape& output_shape, int64_t* output_data) {
ruy::profiler::ScopeLabel label("BroadcastSubSlow/int64");
TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N);
TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N);
NdArrayDesc<N> desc1;
NdArrayDesc<N> desc2;
NdArrayDesc<N> output_desc;
NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
&desc2);
CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc);
// In Tensorflow, the dimensions are canonically named (batch_number, row,
// col, channel), with extents (batches, height, width, depth), with the
// trailing dimension changing most rapidly (channels has the smallest stride,
// typically 1 element).
//
// In generated C code, we store arrays with the dimensions reversed. The
// first dimension has smallest stride.
//
// We name our variables by their Tensorflow convention, but generate C code
// nesting loops such that the innermost loop has the smallest stride for the
// best cache behavior.
auto sub_func = [&](int indexes[N]) {
output_data[SubscriptToIndex(output_desc, indexes)] =
ActivationFunctionWithMinMax(
input1_data[SubscriptToIndex(desc1, indexes)] -
input2_data[SubscriptToIndex(desc2, indexes)],
params.int64_activation_min, params.int64_activation_max);
};
NDOpsHelper<N>(output_desc, sub_func);
}
template <typename T, int N = 5> template <typename T, int N = 5>
void BroadcastSubSlow(const ArithmeticParams& params, void BroadcastSubSlow(const ArithmeticParams& params,
const RuntimeShape& input1_shape, const T* input1_data, const RuntimeShape& input1_shape, const T* input1_data,
@ -434,40 +473,42 @@ void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
} }
} }
inline void SubWithActivation(const ArithmeticParams& params, inline void SetActivationMinMax(const ArithmeticParams& params,
const RuntimeShape& input1_shape, int32* activation_min, int32* activation_max) {
const int32* input1_data, *activation_min = params.quantized_activation_min;
const RuntimeShape& input2_shape, *activation_max = params.quantized_activation_max;
const int32* input2_data, }
const RuntimeShape& output_shape,
int32* output_data) { inline void SetActivationMinMax(const ArithmeticParams& params,
float* activation_min, float* activation_max) {
*activation_min = params.float_activation_min;
*activation_max = params.float_activation_max;
}
inline void SetActivationMinMax(const ArithmeticParams& params,
int64_t* activation_min,
int64_t* activation_max) {
*activation_min = params.int64_activation_min;
*activation_max = params.int64_activation_max;
}
template <typename T>
inline void SubWithActivation(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
ruy::profiler::ScopeLabel label("SubWithActivation"); ruy::profiler::ScopeLabel label("SubWithActivation");
const int flat_size = const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape); MatchingElementsSize(input1_shape, input2_shape, output_shape);
T activation_min, activation_max;
SetActivationMinMax(params, &activation_min, &activation_max);
for (int i = 0; i < flat_size; ++i) { for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax( output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.quantized_activation_min, input1_data[i] - input2_data[i], activation_min, activation_max);
params.quantized_activation_max);
} }
} }
inline void SubWithActivation(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const float* input1_data,
const RuntimeShape& input2_shape,
const float* input2_data,
const RuntimeShape& output_shape,
float* output_data) {
const int flat_size =
MatchingElementsSize(input1_shape, input2_shape, output_shape);
for (int i = 0; i < flat_size; ++i) {
output_data[i] = ActivationFunctionWithMinMax(
input1_data[i] - input2_data[i], params.float_activation_min,
params.float_activation_max);
}
}
} // namespace reference_ops } // namespace reference_ops
} // namespace tflite } // namespace tflite
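For the new int64 path, the reference SubWithActivation behaves exactly like the float/int32 versions, just with 64-bit bounds. A rough usage sketch that mirrors the Int64SubOpModel.NoActivation test added later in this change (assuming the usual tflite includes; RuntimeShape's initializer-list constructor is used for brevity):

#include <cstdint>
#include <limits>

void Int64SubSketch() {
  const tflite::RuntimeShape shape({1, 2, 2, 1});
  const int64_t in1[] = {-20, 2, 7, 8};
  const int64_t in2[] = {1, 2, 3, 5};
  int64_t out[4];  // expected: {-21, 0, 4, 3}
  tflite::ArithmeticParams params;
  tflite::SetActivationParams(std::numeric_limits<int64_t>::lowest(),
                              std::numeric_limits<int64_t>::max(), &params);
  tflite::reference_ops::SubWithActivation(params, shape, in1, shape, in2,
                                           shape, out);
}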

View File

@ -765,12 +765,17 @@ struct ArithmeticParams {
int input1_shift; int input1_shift;
int32 input2_multiplier; int32 input2_multiplier;
int input2_shift; int input2_shift;
// TODO(b/158622529): Union the following activation params.
// uint8, etc, activation params. // uint8, etc, activation params.
int32 quantized_activation_min; int32 quantized_activation_min;
int32 quantized_activation_max; int32 quantized_activation_max;
// float activation params. // float activation params.
float float_activation_min; float float_activation_min;
float float_activation_max; float float_activation_max;
// int64 activation params.
int64_t int64_activation_min;
int64_t int64_activation_max;
// Processed output dimensions. // Processed output dimensions.
// Let input "a" be the one that broadcasts in the faster-changing dimension. // Let input "a" be the one that broadcasts in the faster-changing dimension.
@ -1114,6 +1119,12 @@ inline void SetActivationParams(int32 min, int32 max, P* params) {
params->quantized_activation_max = max; params->quantized_activation_max = max;
} }
template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
params->int64_activation_min = min;
params->int64_activation_max = max;
}
template <typename P> template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) { inline void GetActivationParams(const P& params, int32* min, int32* max) {
*min = params.quantized_activation_min; *min = params.quantized_activation_min;
@ -1126,6 +1137,11 @@ inline void GetActivationParams(const P& params, float* min, float* max) {
*max = params.float_activation_max; *max = params.float_activation_max;
} }
template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
*min = params.int64_activation_min;
*max = params.int64_activation_max;
}
} // namespace tflite } // namespace tflite
#endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_ #endif // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

View File

@ -44,6 +44,7 @@ inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
int index) { int index) {
return &context->tensors[node->outputs->data[index]]; return &context->tensors[node->outputs->data[index]];
} }
#ifndef TF_LITE_STATIC_MEMORY
inline TfLiteTensor* GetTemporary(TfLiteContext* context, inline TfLiteTensor* GetTemporary(TfLiteContext* context,
const TfLiteNode* node, int index) { const TfLiteNode* node, int index) {
return &context->tensors[node->temporaries->data[index]]; return &context->tensors[node->temporaries->data[index]];
@ -52,11 +53,12 @@ inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
const TfLiteNode* node, int index) { const TfLiteNode* node, int index) {
return &context->tensors[node->intermediates->data[index]]; return &context->tensors[node->intermediates->data[index]];
} }
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
inline int NumIntermediates(const TfLiteNode* node) { inline int NumIntermediates(const TfLiteNode* node) {
return node->intermediates->size; return node->intermediates->size;
} }
#endif // TF_LITE_STATIC_MEMORY
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
inline int64_t NumElements(const TfLiteIntArray* dims) { inline int64_t NumElements(const TfLiteIntArray* dims) {
int64_t count = 1; int64_t count = 1;

View File

@ -142,8 +142,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
/* min_version */ 1, /* min_version */ 1,
/* max_version */ 2); /* max_version */ 2);
AddBuiltin(BuiltinOperator_SUB, Register_SUB(), AddBuiltin(BuiltinOperator_SUB, Register_SUB(),
/* min_version */ 1, /* min_version = */ 1,
/* max_version */ 5); /* max_version = */ 5);
AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(), AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(),
/* min_version = */ 1, /* min_version = */ 1,
/* max_version = */ 4); /* max_version = */ 4);

View File

@ -340,6 +340,11 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
EvalSubImpl<kernel_type, float>(context, node, params, data, input1, EvalSubImpl<kernel_type, float>(context, node, params, data, input1,
input2, requires_broadcast, output); input2, requires_broadcast, output);
break; break;
case kTfLiteInt64:
EvalSubImpl<kernel_type, int64_t>(context, node, params, data, input1,
input2, requires_broadcast, output);
break;
default: default:
TF_LITE_KERNEL_LOG(context, "output type %s is not supported.", TF_LITE_KERNEL_LOG(context, "output type %s is not supported.",
TfLiteTypeGetName(output->type)); TfLiteTypeGetName(output->type));
@ -434,7 +439,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2); const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
TfLiteTensor* output = GetOutput(context, node, kOutputTensor); TfLiteTensor* output = GetOutput(context, node, kOutputTensor);
if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) { if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
output->type == kTfLiteInt64) {
EvalSub<kernel_type>(context, node, params, data, input1, input2, output); EvalSub<kernel_type>(context, node, params, data, input1, input2, output);
} else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 || } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
output->type == kTfLiteInt16) { output->type == kTfLiteInt16) {

View File

@ -63,6 +63,13 @@ class IntegerSubOpModel : public BaseSubOpModel {
std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); } std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
}; };
class Int64SubOpModel : public BaseSubOpModel {
public:
using BaseSubOpModel::BaseSubOpModel;
std::vector<int64_t> GetOutput() { return ExtractVector<int64_t>(output_); }
};
class QuantizedSubOpModel : public BaseSubOpModel { class QuantizedSubOpModel : public BaseSubOpModel {
public: public:
using BaseSubOpModel::BaseSubOpModel; using BaseSubOpModel::BaseSubOpModel;
@ -213,6 +220,57 @@ TEST(IntegerSubOpModel, WithBroadcast) {
} }
} }
TEST(Int64SubOpModel, NoActivation) {
Int64SubOpModel m({TensorType_INT64, {1, 2, 2, 1}},
{TensorType_INT64, {1, 2, 2, 1}}, {TensorType_INT64, {}},
ActivationFunctionType_NONE);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8});
m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3}));
}
TEST(Int64SubOpModel, ActivationRELU_N1_TO_1) {
Int64SubOpModel m({TensorType_INT64, {1, 2, 2, 1}},
{TensorType_INT64, {1, 2, 2, 1}}, {TensorType_INT64, {}},
ActivationFunctionType_RELU_N1_TO_1);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8});
m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 0, 1, 1}));
}
TEST(Int64SubOpModel, VariousInputShapes) {
std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
for (int i = 0; i < test_shapes.size(); ++i) {
Int64SubOpModel m({TensorType_INT64, test_shapes[i]},
{TensorType_INT64, test_shapes[i]},
{TensorType_INT64, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5, 11, 1});
m.Invoke();
EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3, 0, 19}))
<< "With shape number " << i;
}
}
TEST(Int64SubOpModel, WithBroadcast) {
std::vector<std::vector<int>> test_shapes = {
{6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}, {1, 3, 1, 2, 1}};
for (int i = 0; i < test_shapes.size(); ++i) {
Int64SubOpModel m({TensorType_INT64, test_shapes[i]},
{TensorType_INT64, {}}, // always a scalar
{TensorType_INT64, {}}, ActivationFunctionType_NONE);
m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
m.PopulateTensor<int64_t>(m.input2(), {1});
m.Invoke();
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear({-21, 1, 6, 7, 10, 19})))
<< "With shape number " << i;
}
}
template <TensorType tensor_type, typename integer_dtype> template <TensorType tensor_type, typename integer_dtype>
void QuantizedTestsNoActivation() { void QuantizedTestsNoActivation() {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0); float kQuantizedTolerance = GetTolerance(-1.0, 1.0);

View File

@ -53,19 +53,14 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
// There is 1 output at index 3 in the tensors array. // There is 1 output at index 3 in the tensors array.
int outputs_array_data[] = {1, 3}; int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no temporaries.
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(conv_params); node.builtin_data = reinterpret_cast<void*>(conv_params);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node); TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -66,19 +66,14 @@ TfLiteStatus ValidateDepthwiseConvGoldens(TfLiteTensor* tensors,
// There is 1 output at index 3 in the tensors array. // There is 1 output at index 3 in the tensors array.
int outputs_array_data[] = {1, 3}; int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no intermediates.
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node); TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -33,20 +33,23 @@ limitations under the License.
namespace { namespace {
// Create an area of memory to use for input, output, and intermediate arrays. // Create an area of memory to use for input, output, and intermediate arrays.
constexpr int tensor_arena_size = 73 * 1024; // Align arena to 16 bytes to avoid alignment warnings on certain platforms.
uint8_t tensor_arena[tensor_arena_size]; constexpr int tensor_arena_size = 21 * 1024;
alignas(16) uint8_t tensor_arena[tensor_arena_size];
// A random number generator seed to generate input values. // A random number generator seed to generate input values.
constexpr int kRandomSeed = 42; constexpr int kRandomSeed = 42;
MicroBenchmarkRunner<int16_t>& GetBenchmarkRunner() {
// NOLINTNEXTLINE // NOLINTNEXTLINE
MicroBenchmarkRunner<int16_t> runner(g_keyword_scrambled_model_data, static MicroBenchmarkRunner<int16_t> runner(
tensor_arena, tensor_arena_size, g_keyword_scrambled_model_data, tensor_arena, tensor_arena_size, 0);
kRandomSeed); return runner;
}
void KeywordRunTenIerations() { void KeywordRunTenIerations() {
// TODO(b/152644476): Add a way to run more than a single deterministic input. // TODO(b/152644476): Add a way to run more than a single deterministic input.
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
runner.RunSingleIterationRandomInput(); GetBenchmarkRunner().RunSingleIterationRandomInput();
} }
} }
@ -54,7 +57,7 @@ void KeywordRunTenIerations() {
TF_LITE_MICRO_BENCHMARKS_BEGIN TF_LITE_MICRO_BENCHMARKS_BEGIN
TF_LITE_MICRO_BENCHMARK(runner.RunSingleIterationRandomInput()); TF_LITE_MICRO_BENCHMARK(GetBenchmarkRunner().RunSingleIterationRandomInput());
TF_LITE_MICRO_BENCHMARK(KeywordRunTenIerations()); TF_LITE_MICRO_BENCHMARK(KeywordRunTenIerations());
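Replacing the file-scope runner with a function-local static is the usual way to avoid static initialization order problems and to defer the (fairly heavy) interpreter setup until the first benchmark actually runs. The pattern in isolation, with a hypothetical Runner type standing in for MicroBenchmarkRunner:

// Hypothetical stand-in for MicroBenchmarkRunner<int16_t>.
struct Runner {
  Runner(const unsigned char* model, unsigned char* arena, int arena_size) {}
  void RunSingleIterationRandomInput() {}
};

Runner& GetRunner() {
  // Constructed on first use, in a well-defined order, then reused.
  static Runner runner(nullptr, nullptr, 0);
  return runner;
}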

View File

@ -27,10 +27,6 @@ limitations under the License.
#include "tensorflow/lite/schema/schema_generated.h" #include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h" #include "tensorflow/lite/version.h"
// Create an area of memory to use for input, output, and intermediate arrays.
constexpr int tensor_arena_size = 73 * 1024;
uint8_t tensor_arena[tensor_arena_size];
/* /*
* Person Detection benchmark. Evaluates runtime performance of the visual * Person Detection benchmark. Evaluates runtime performance of the visual
* wakewords person detection model. This is the same model found in * wakewords person detection model. This is the same model found in
@ -40,24 +36,28 @@ uint8_t tensor_arena[tensor_arena_size];
namespace { namespace {
// Create an area of memory to use for input, output, and intermediate arrays. // Create an area of memory to use for input, output, and intermediate arrays.
// Align arena to 16 bytes to avoid alignment warnings on certain platforms.
constexpr int tensor_arena_size = 95 * 1024; constexpr int tensor_arena_size = 95 * 1024;
uint8_t tensor_arena[tensor_arena_size]; alignas(16) uint8_t tensor_arena[tensor_arena_size];
MicroBenchmarkRunner<uint8_t>& GetBenchmarkRunner() {
// NOLINTNEXTLINE // NOLINTNEXTLINE
MicroBenchmarkRunner<uint8_t> runner(g_person_detect_model_data, tensor_arena, static MicroBenchmarkRunner<uint8_t> runner(
tensor_arena_size, 0); g_person_detect_model_data, tensor_arena, tensor_arena_size, 0);
return runner;
}
void PersonDetectionTenIerationsWithPerson() { void PersonDetectionTenIerationsWithPerson() {
// TODO(b/152644476): Add a way to run more than a single deterministic input. // TODO(b/152644476): Add a way to run more than a single deterministic input.
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
runner.RunSingleIterationCustomInput(g_person_data); GetBenchmarkRunner().RunSingleIterationCustomInput(g_person_data);
} }
} }
void PersonDetectionTenIerationsWithoutPerson() { void PersonDetectionTenIerationsWithoutPerson() {
// TODO(b/152644476): Add a way to run more than a single deterministic input. // TODO(b/152644476): Add a way to run more than a single deterministic input.
for (int i = 0; i < 10; i++) { for (int i = 0; i < 10; i++) {
runner.RunSingleIterationCustomInput(g_no_person_data); GetBenchmarkRunner().RunSingleIterationCustomInput(g_no_person_data);
} }
} }
@ -65,7 +65,8 @@ void PersonDetectionTenIerationsWithoutPerson() {
TF_LITE_MICRO_BENCHMARKS_BEGIN TF_LITE_MICRO_BENCHMARKS_BEGIN
TF_LITE_MICRO_BENCHMARK(runner.RunSingleIterationCustomInput(g_person_data)); TF_LITE_MICRO_BENCHMARK(
GetBenchmarkRunner().RunSingleIterationCustomInput(g_person_data));
TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithPerson()); TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithPerson());
TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithoutPerson()); TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithoutPerson());
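
The person-detection benchmark keeps its 95 * 1024 arena but now declares it alignas(16), matching the new comment about alignment warnings. A short self-contained sketch of what the attribute guarantees (plain C++ only, nothing TFLite-specific assumed):

#include <cstdint>
#include <cstdio>

constexpr int tensor_arena_size = 95 * 1024;
// A uint8_t array is only guaranteed 1-byte alignment by default; forcing
// 16-byte alignment lets an arena allocator hand out aligned tensor buffers
// without internal padding and silences alignment warnings on some targets.
alignas(16) uint8_t tensor_arena[tensor_arena_size];

int main() {
  const uintptr_t addr = reinterpret_cast<uintptr_t>(&tensor_arena[0]);
  std::printf("alignment remainder: %u\n",
              static_cast<unsigned>(addr % 16));  // prints 0
  return 0;
}
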

View File

@ -60,12 +60,10 @@ void TestReluFloat(const int* input_dims_data, const float* input_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -116,12 +114,10 @@ void TestRelu6Float(const int* input_dims_data, const float* input_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -177,12 +173,10 @@ void TestReluUint8(const int* input_dims_data, const float* input_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -242,12 +236,10 @@ void TestRelu6Uint8(const int* input_dims_data, const float* input_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -306,12 +298,10 @@ void TestReluInt8(const int* input_dims_data, const float* input_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -372,12 +362,10 @@ void TestRelu6Int8(const int* input_dims_data, const float* input_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
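
The activation tests above, like the add, arg_min_max, ceil, circular_buffer, comparisons, conv, depthwise_conv, dequantize, pooling, pad, prelu, quantize and related hunks below, drop the node.temporaries and node.delegate assignments and, where present, the temporaries_array built from {0}. The setup they converge on looks like the fragment below. This is an illustrative fragment rather than a standalone program: it assumes the TFLite Micro test scaffolding used throughout these files (IntArrayFromInts, the micro_test macros, a context filled in by PopulateContext), and the function name RunSingleInputKernel is hypothetical.

// One input tensor (index 0), one output tensor (index 1); the leading
// element of each array is the count, as IntArrayFromInts expects.
void RunSingleInputKernel(TfLiteContext* context,
                          const TfLiteRegistration* registration,
                          void* user_data) {
  int inputs_array_data[] = {1, 0};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  // node.temporaries and node.delegate are left unset: this commit removes
  // those assignments from every kernel test, so the kernels under test do
  // not depend on them being initialized.

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(context, &node));
  }
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(context, &node));
}
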

View File

@ -89,18 +89,14 @@ void ValidateAddGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -46,17 +46,13 @@ void ValidateArgMinMaxGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }

View File

@ -46,17 +46,13 @@ void TestCeil(const int* input_dims_data, const float* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_dims_count; ++i) { for (int i = 0; i < output_dims_count; ++i) {

View File

@ -56,17 +56,12 @@ TfLiteNode PrepareCircularBufferInt8(const int* input_dims_data,
// There is one output - tensor 1. // There is one output - tensor 1.
const int outputs_array_data[] = {1, 1}; const int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no intermediates.
const int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -104,17 +99,12 @@ TfLiteStatus InvokeCircularBufferInt8(const int* input_dims_data,
// There is one output - tensor 1. // There is one output - tensor 1.
const int outputs_array_data[] = {1, 1}; const int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
// There are no intermediates.
const int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
node->inputs = inputs_array; node->inputs = inputs_array;
node->outputs = outputs_array; node->outputs = outputs_array;
node->temporaries = temporaries_array;
node->builtin_data = nullptr; node->builtin_data = nullptr;
node->custom_initial_data = nullptr; node->custom_initial_data = nullptr;
node->custom_initial_data_size = 0; node->custom_initial_data_size = 0;
node->delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);

View File

@ -44,18 +44,14 @@ void TestComparison(tflite::BuiltinOperator op, TfLiteTensor* tensors,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
const int outputs_array_data[] = {1, 2}; const int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
const int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -57,19 +57,18 @@ void TestConcatenateTwoInputs(std::initializer_list<int> input1_dims_data,
.activation = kTfLiteActNone // Only activation supported in this impl .activation = kTfLiteActNone // Only activation supported in this impl
}; };
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -116,19 +115,18 @@ void TestConcatenateQuantizedTwoInputs(
.activation = kTfLiteActNone // Only activation supported in this impl .activation = kTfLiteActNone // Only activation supported in this impl
}; };
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
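
This hunk and several later ones (elementwise, logical, maximum_minimum, neg, pack) also replace IntArrayFromInitializer({...}) with IntArrayFromInts over a named int array, where element 0 is the count and the remaining elements are tensor indices. A likely motivation is lifetime: a std::initializer_list's backing storage is a temporary, whereas a named local array stays valid for the whole test, and the helper simply reinterprets that length-prefixed array as a TfLiteIntArray. A self-contained sketch of the layout convention (no TFLite types assumed):

#include <cstdio>

int main() {
  // Same shape as the {2, 0, 1} arrays in the diff: the first element is
  // the count, the rest are tensor indices (here: two inputs, tensors 0, 1).
  int inputs_array_data[] = {2, 0, 1};
  const int size = inputs_array_data[0];
  std::printf("%d input tensor indices:", size);
  for (int i = 0; i < size; ++i) {
    std::printf(" %d", inputs_array_data[1 + i]);
  }
  std::printf("\n");
  return 0;
}
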

View File

@ -76,18 +76,14 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3}; int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(conv_params); node.builtin_data = reinterpret_cast<void*>(conv_params);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node)); TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node));

View File

@ -73,18 +73,14 @@ TfLiteStatus ValidateDepthwiseConvGoldens(const T* expected_output_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3}; int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node)); TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node));
} }

View File

@ -48,18 +48,14 @@ void ValidateDequantizeGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -54,23 +54,18 @@ void TestElementwiseFloat(tflite::BuiltinOperator op,
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
auto inputs_array_data = {1, 0}; int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInitializer(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
auto outputs_array_data = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
auto temporaries_array_data = {0};
TfLiteIntArray* temporaries_array =
IntArrayFromInitializer(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -119,19 +114,19 @@ void TestElementwiseBool(tflite::BuiltinOperator op,
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 1}); int inputs_array_data[] = {1, 0};
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -47,18 +47,13 @@ void TestFloor(const int* input_dims_data, const float* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int intermediates_array_data[] = {0};
TfLiteIntArray* temporaries_array =
IntArrayFromInts(intermediates_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_dims_count; ++i) { for (int i = 0; i < output_dims_count; ++i) {

View File

@ -72,18 +72,14 @@ TfLiteStatus TestFullyConnectedFloat(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3}; int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node)); TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node));
} }
@ -151,18 +147,14 @@ TfLiteStatus TestFullyConnectedQuantized(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3}; int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node)); TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node));

View File

@ -112,18 +112,14 @@ void TestL2Normalization(const int* input_dims_data, const T* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

View File

@ -51,19 +51,18 @@ void TestLogicalOp(tflite::BuiltinOperator op,
const TfLiteRegistration* registration = resolver.FindOp(op); const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration); TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -62,12 +62,10 @@ void TestLogisticFloat(std::initializer_list<int> input_dims_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -122,12 +120,10 @@ void TestLogisticInt8(std::initializer_list<int> input_dims_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }

View File

@ -52,19 +52,18 @@ void TestMaxMinFloat(tflite::BuiltinOperator op,
const TfLiteRegistration* registration = resolver.FindOp(op); const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration); TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -108,19 +107,18 @@ void TestMaxMinQuantized(
const TfLiteRegistration* registration = resolver.FindOp(op); const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration); TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -162,19 +160,18 @@ void TestMaxMinQuantizedInt32(
const TfLiteRegistration* registration = resolver.FindOp(op); const TfLiteRegistration* registration = resolver.FindOp(op);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration); TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -76,7 +76,6 @@ void TestMulFloat(std::initializer_list<int> input1_dims_data,
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -148,7 +147,6 @@ void TestMulQuantized(std::initializer_list<int> input1_dims_data,
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -50,17 +50,14 @@ void TestNegFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

View File

@ -64,19 +64,19 @@ void TestPackTwoInputsFloat(std::initializer_list<int> input1_dims_data,
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -141,19 +141,19 @@ void TestPackThreeInputsFloat(std::initializer_list<int> input1_dims_data,
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({3, 0, 1, 2});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 3}); int inputs_array_data[] = {3, 0, 1, 2};
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -214,19 +214,18 @@ void TestPackTwoInputsQuantized(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -284,19 +283,18 @@ void TestPackTwoInputsQuantized32(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -39,17 +39,13 @@ TfLiteStatus ValidatePadGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
TF_LITE_ENSURE_EQ(&context, kTfLiteOk, TF_LITE_ENSURE_EQ(&context, kTfLiteOk,
registration->prepare(&context, &node)); registration->prepare(&context, &node));
@ -76,17 +72,13 @@ TfLiteStatus ValidatePadV2Goldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3}; int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
// Prepare should catch dimension mismatches. // Prepare should catch dimension mismatches.
TfLiteStatus prepare_status = registration->prepare(&context, &node); TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -66,18 +66,14 @@ void TestAveragePoolingFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -138,18 +134,14 @@ void TestAveragePoolingQuantized(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -209,18 +201,14 @@ void TestMaxPoolFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -283,18 +271,14 @@ void TestMaxPoolQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }

View File

@ -58,16 +58,13 @@ void TestPreluFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -122,16 +119,13 @@ void TestPreluQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }

View File

@ -50,18 +50,14 @@ void ValidateQuantizeGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -74,12 +74,10 @@ TfLiteStatus ValidateReduceGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(params); node.builtin_data = reinterpret_cast<void*>(params);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -29,56 +29,29 @@ namespace tflite {
namespace testing { namespace testing {
namespace { namespace {
// If expected output is empty, the test is expected to fail.
template <typename T> template <typename T>
void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor, void TestReshapeImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output_tensor, TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output, std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims, std::initializer_list<int> expected_dims,
bool expect_failure) { bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
if (shape_tensor == nullptr) {
constexpr int inputs_size = 1;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *output_tensor,
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
node.inputs = IntArrayFromInitializer({1, 0});
node.outputs = IntArrayFromInitializer({1, 1});
} else {
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *shape_tensor;
tensors[2] = *output_tensor;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
node.inputs = IntArrayFromInitializer({2, 0, 1});
node.outputs = IntArrayFromInitializer({1, 2});
}
::tflite::AllOpsResolver resolver; ::tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration = const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_RESHAPE); resolver.FindOp(tflite::BuiltinOperator_RESHAPE);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration); TF_LITE_MICRO_EXPECT_NE(nullptr, registration);
void* user_data = nullptr; void* user_data = nullptr;
node.temporaries = nullptr; node->user_data = user_data;
node.user_data = user_data; node->builtin_data = nullptr;
node.builtin_data = nullptr; node->custom_initial_data = nullptr;
node.custom_initial_data = nullptr; node->custom_initial_data_size = 0;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_EQ(registration->init, nullptr); TF_LITE_MICRO_EXPECT_EQ(registration->init, nullptr);
TF_LITE_MICRO_EXPECT_EQ(registration->free, nullptr); TF_LITE_MICRO_EXPECT_EQ(registration->free, nullptr);
if (registration->prepare) { if (registration->prepare) {
// Error can happen either in Prepare or eval stage. // Error can happen either in Prepare or eval stage.
auto status = registration->prepare(&context, &node); auto status = registration->prepare(context, node);
if (status != kTfLiteOk && expect_failure) { if (status != kTfLiteOk && expect_failure) {
return; return;
} else { } else {
@ -86,11 +59,10 @@ void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
} }
} }
if (expect_failure) { if (expect_failure) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, registration->invoke(context, node));
registration->invoke(&context, &node));
return; return;
} }
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(context, node));
const int output_dims_count = ElementCount(*output_tensor->dims); const int output_dims_count = ElementCount(*output_tensor->dims);
const T* output_data = GetTensorData<T>(output_tensor); const T* output_data = GetTensorData<T>(output_tensor);
@ -105,6 +77,59 @@ void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
} }
} }
// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeWithShapeImpl(TfLiteTensor* input_tensor,
TfLiteTensor* shape_tensor,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *shape_tensor;
tensors[2] = *output_tensor;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
int inputs_data[] = {2, 0, 1};
node.inputs = IntArrayFromInts(inputs_data);
int outputs_data[] = {1, 2};
node.outputs = IntArrayFromInts(outputs_data);
TestReshapeImpl(&context, &node, output_tensor, expected_output,
expected_dims, expect_failure);
}
// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeWithoutShapeImpl(TfLiteTensor* input_tensor,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
constexpr int inputs_size = 1;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *output_tensor;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
int inputs_data[] = {1, 0};
node.inputs = IntArrayFromInts(inputs_data);
int outputs_data[] = {1, 1};
node.outputs = IntArrayFromInts(outputs_data);
TestReshapeImpl(&context, &node, output_tensor, expected_output,
expected_dims, expect_failure);
}
template <typename T = float, TfLiteType tensor_input_type = kTfLiteFloat32> template <typename T = float, TfLiteType tensor_input_type = kTfLiteFloat32>
void TestReshape(std::initializer_list<int> input_dims_data, void TestReshape(std::initializer_list<int> input_dims_data,
std::initializer_list<T> input_data, std::initializer_list<T> input_data,
@ -122,13 +147,13 @@ void TestReshape(std::initializer_list<int> input_dims_data,
TfLiteTensor output_tensor = TfLiteTensor output_tensor =
CreateTensor<T, tensor_input_type>(output_data, output_dims); CreateTensor<T, tensor_input_type>(output_data, output_dims);
// Reshape param is passed as op's param. // Reshape param is passed as op's param.
TestReshapeImpl<T>(&input_tensor, nullptr, &output_tensor, expected_output, TestReshapeWithoutShapeImpl<T>(&input_tensor, &output_tensor, expected_output,
expected_dims, expect_failure); expected_dims, expect_failure);
// Reshape param is passed as a tensor. // Reshape param is passed as a tensor.
TfLiteIntArray* shape_dims = IntArrayFromInitializer(shape_dims_data); TfLiteIntArray* shape_dims = IntArrayFromInitializer(shape_dims_data);
auto shape_tensor = auto shape_tensor =
CreateTensor<int32_t, kTfLiteInt32>(shape_data, shape_dims); CreateTensor<int32_t, kTfLiteInt32>(shape_data, shape_dims);
TestReshapeImpl<T>(&input_tensor, &shape_tensor, &output_tensor, TestReshapeWithShapeImpl<T>(&input_tensor, &shape_tensor, &output_tensor,
expected_output, expected_dims, expect_failure); expected_output, expected_dims, expect_failure);
} }
} // namespace } // namespace
@ -192,15 +217,16 @@ TF_LITE_MICRO_TEST(InvalidShape) {
using tflite::testing::CreateFloatTensor; using tflite::testing::CreateFloatTensor;
using tflite::testing::IntArrayFromInitializer; using tflite::testing::IntArrayFromInitializer;
using tflite::testing::IntArrayFromInts; using tflite::testing::IntArrayFromInts;
TfLiteIntArray* input_dims = IntArrayFromInitializer({3, 1, 2, 2}); int input_dims_data[] = {3, 1, 2, 2};
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
auto input_data = {3.0f}; auto input_data = {3.0f};
auto input_tensor = CreateFloatTensor(input_data, input_dims); auto input_tensor = CreateFloatTensor(input_data, input_dims);
float output_data[4]; float output_data[4];
int output_dims_data[6] = {2, 2, 1, 2, 2, 1}; int output_dims_data[6] = {2, 2, 1, 2, 2, 1};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
auto output_tensor = CreateFloatTensor(output_data, output_dims); auto output_tensor = CreateFloatTensor(output_data, output_dims);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor tflite::testing::TestReshapeWithoutShapeImpl<float>(
nullptr, // shape_tensor &input_tensor, // input_tensor
&output_tensor, // output_tensor &output_tensor, // output_tensor
{}, // expected_output {}, // expected_output
{}, // expected_dims {}, // expected_dims
@ -255,25 +281,28 @@ TF_LITE_MICRO_TEST(LegacyScalarOutput) {
using tflite::testing::CreateFloatTensor; using tflite::testing::CreateFloatTensor;
using tflite::testing::IntArrayFromInitializer; using tflite::testing::IntArrayFromInitializer;
using tflite::testing::IntArrayFromInts; using tflite::testing::IntArrayFromInts;
TfLiteIntArray* input_dims = IntArrayFromInitializer({1, 1}); int input_dims_data[] = {1, 1};
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
auto input_data = {3.0f}; auto input_data = {3.0f};
auto input_tensor = CreateFloatTensor(input_data, input_dims); auto input_tensor = CreateFloatTensor(input_data, input_dims);
float output_data[1]; float output_data[1];
int output_dims_data[2] = {1, 0}; int output_dims_data[2] = {1, 0};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data); TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
auto output_tensor = CreateFloatTensor(output_data, output_dims); auto output_tensor = CreateFloatTensor(output_data, output_dims);
TfLiteIntArray* shape_dims = tflite::testing::IntArrayFromInitializer({1, 0}); int shape_dims_data[] = {1, 0};
TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
auto shape_tensor = auto shape_tensor =
tflite::testing::CreateTensor<int32_t, kTfLiteInt32>({0}, shape_dims); tflite::testing::CreateTensor<int32_t, kTfLiteInt32>({0}, shape_dims);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor tflite::testing::TestReshapeWithShapeImpl<float>(
&input_tensor, // input_tensor
&shape_tensor, // shape_tensor &shape_tensor, // shape_tensor
&output_tensor, // output_tensor &output_tensor, // output_tensor
{}, // expected_output {}, // expected_output
{}, // expected_dims {}, // expected_dims
true // expect failure true // expect failure
); );
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor tflite::testing::TestReshapeWithoutShapeImpl<float>(
nullptr, // shape_tensor &input_tensor, // input_tensor
&output_tensor, // output_tensor &output_tensor, // output_tensor
{3}, // expected_output {3}, // expected_output
{}, // expected_dims {}, // expected_dims

View File

@ -78,18 +78,14 @@ void TestResizeNearestNeighbor(const int* input_dims_data, const T* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

View File

@ -46,17 +46,13 @@ void TestRound(const int* input_dims_data, const float* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr; node.user_data = nullptr;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke); TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_dims_count; ++i) { for (int i = 0; i < output_dims_count; ++i) {

View File

@ -59,18 +59,14 @@ void TestSoftmaxFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -124,18 +120,14 @@ void TestSoftmaxQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -188,18 +180,14 @@ void TestSoftmaxQuantizedSigned(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1}; int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -76,19 +76,19 @@ void TestSplitTwoOutputsFloat(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -179,19 +179,19 @@ void TestSplitFourOutputsFloat(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({4, 2, 3, 4, 5}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {4, 2, 3, 4, 5};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -275,19 +275,19 @@ void TestSplitTwoOutputsQuantized(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -362,19 +362,19 @@ void TestSplitTwoOutputsQuantized32(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3}); int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -100,12 +100,10 @@ void TestStrideSlide(std::initializer_list<int> input_shape,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
if (expect_prepare_err) { if (expect_prepare_err) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,

View File

@ -89,18 +89,14 @@ void ValidateSubGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2}; int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data); TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -201,7 +201,6 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
node.builtin_data = reinterpret_cast<void*>(&params); node.builtin_data = reinterpret_cast<void*>(&params);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node); TfLiteStatus prepare_status = registration->prepare(&context, &node);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, prepare_status); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, prepare_status);
@ -275,7 +274,6 @@ void ValidateIntegerSVDFGoldens(const int batch_size, const int num_units,
node.builtin_data = reinterpret_cast<void*>(&params); node.builtin_data = reinterpret_cast<void*>(&params);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node); TfLiteStatus prepare_status = registration->prepare(&context, &node);

View File

@ -62,12 +62,10 @@ void TestTanhFloat(std::initializer_list<int> input_dims_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }
@ -122,12 +120,10 @@ void TestTanhInt8(std::initializer_list<int> input_dims_data,
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = nullptr; node.builtin_data = nullptr;
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
} }

View File

@ -79,19 +79,18 @@ void TestUnpackThreeOutputsFloat(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0}); int inputs_array_data[] = {1, 0};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -156,19 +155,18 @@ void TestUnpackOneOutputFloat(std::initializer_list<int> input_dims_data,
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0}); int inputs_array_data[] = {1, 0};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 1}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -245,19 +243,18 @@ void TestUnpackThreeOutputsQuantized(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0}); int inputs_array_data[] = {1, 0};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -338,19 +335,18 @@ void TestUnpackThreeOutputsQuantized32(
if (registration->init) { if (registration->init) {
user_data = registration->init(&context, nullptr, 0); user_data = registration->init(&context, nullptr, 0);
} }
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0}); int inputs_array_data[] = {1, 0};
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3}); TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0}); int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteNode node; TfLiteNode node;
node.inputs = inputs_array; node.inputs = inputs_array;
node.outputs = outputs_array; node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data; node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data); node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr; node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0; node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) { if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node)); TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

View File

@ -45,8 +45,8 @@ constexpr int kKeywordModelNodeAndRegistrationCount = 15;
// Run this test with '--copt=-DTF_LITE_MICRO_OPTIMIZED_RUNTIME' to get // Run this test with '--copt=-DTF_LITE_MICRO_OPTIMIZED_RUNTIME' to get
// optimized memory runtime values: // optimized memory runtime values:
#ifdef TF_LITE_STATIC_MEMORY #ifdef TF_LITE_STATIC_MEMORY
constexpr int kKeywordModelTotalSize = 18448; constexpr int kKeywordModelTotalSize = 18080;
constexpr int kKeywordModelTailSize = 17776; constexpr int kKeywordModelTailSize = 17408;
#else #else
constexpr int kKeywordModelTotalSize = 21040; constexpr int kKeywordModelTotalSize = 21040;
constexpr int kKeywordModelTailSize = 20368; constexpr int kKeywordModelTailSize = 20368;
@ -65,8 +65,8 @@ constexpr int kTestConvModelNodeAndRegistrationCount = 7;
// NOTE: These values are measured on x86-64: // NOTE: These values are measured on x86-64:
// TODO(b/158651472): Consider auditing these values on non-64 bit systems. // TODO(b/158651472): Consider auditing these values on non-64 bit systems.
#ifdef TF_LITE_STATIC_MEMORY #ifdef TF_LITE_STATIC_MEMORY
constexpr int kTestConvModelTotalSize = 10960; constexpr int kTestConvModelTotalSize = 10784;
constexpr int kTestConvModelTailSize = 3216; constexpr int kTestConvModelTailSize = 3040;
#else #else
constexpr int kTestConvModelTotalSize = 11680; constexpr int kTestConvModelTotalSize = 11680;
constexpr int kTestConvModelTailSize = 3936; constexpr int kTestConvModelTailSize = 3936;

View File

@ -128,26 +128,20 @@ class MicroMutableOpResolver : public MicroOpResolver {
} }
TfLiteStatus AddAveragePool2D() { TfLiteStatus AddAveragePool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D, return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
*tflite::ops::micro::Register_AVERAGE_POOL_2D(), *tflite::ops::micro::Register_AVERAGE_POOL_2D(),
ParseOpData); ParsePool);
} }
TfLiteStatus AddCeil() { TfLiteStatus AddCeil() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CEIL, return AddBuiltin(BuiltinOperator_CEIL,
*tflite::ops::micro::Register_CEIL(), ParseOpData); *tflite::ops::micro::Register_CEIL(), ParseCeil);
} }
TfLiteStatus AddConcatenation() { TfLiteStatus AddConcatenation() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CONCATENATION, return AddBuiltin(BuiltinOperator_CONCATENATION,
*tflite::ops::micro::Register_CONCATENATION(), *tflite::ops::micro::Register_CONCATENATION(),
ParseOpData); ParseConcatenation);
} }
TfLiteStatus AddConv2D() { TfLiteStatus AddConv2D() {
@ -156,10 +150,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
} }
TfLiteStatus AddCos() { TfLiteStatus AddCos() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_COS, *tflite::ops::micro::Register_COS(), return AddBuiltin(BuiltinOperator_COS, *tflite::ops::micro::Register_COS(),
ParseOpData); ParseCos);
} }
TfLiteStatus AddDepthwiseConv2D() { TfLiteStatus AddDepthwiseConv2D() {
@ -175,17 +167,13 @@ class MicroMutableOpResolver : public MicroOpResolver {
} }
TfLiteStatus AddEqual() { TfLiteStatus AddEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_EQUAL, return AddBuiltin(BuiltinOperator_EQUAL,
*tflite::ops::micro::Register_EQUAL(), ParseOpData); *tflite::ops::micro::Register_EQUAL(), ParseEqual);
} }
TfLiteStatus AddFloor() { TfLiteStatus AddFloor() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_FLOOR, return AddBuiltin(BuiltinOperator_FLOOR,
*tflite::ops::micro::Register_FLOOR(), ParseOpData); *tflite::ops::micro::Register_FLOOR(), ParseFloor);
} }
TfLiteStatus AddFullyConnected() { TfLiteStatus AddFullyConnected() {
@ -195,18 +183,14 @@ class MicroMutableOpResolver : public MicroOpResolver {
} }
TfLiteStatus AddGreater() { TfLiteStatus AddGreater() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER, return AddBuiltin(BuiltinOperator_GREATER,
*tflite::ops::micro::Register_GREATER(), ParseOpData); *tflite::ops::micro::Register_GREATER(), ParseGreater);
} }
TfLiteStatus AddGreaterEqual() { TfLiteStatus AddGreaterEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER_EQUAL, return AddBuiltin(BuiltinOperator_GREATER_EQUAL,
*tflite::ops::micro::Register_GREATER_EQUAL(), *tflite::ops::micro::Register_GREATER_EQUAL(),
ParseOpData); ParseGreaterEqual);
} }
TfLiteStatus AddL2Normalization() { TfLiteStatus AddL2Normalization() {
@ -274,10 +258,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
} }
TfLiteStatus AddMaxPool2D() { TfLiteStatus AddMaxPool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MAX_POOL_2D, return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
*tflite::ops::micro::Register_MAX_POOL_2D(), ParseOpData); *tflite::ops::micro::Register_MAX_POOL_2D(), ParsePool);
} }
TfLiteStatus AddMean() { TfLiteStatus AddMean() {
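The Add* methods above now pass operator-specific parse functions (ParsePool, ParseCeil, ParseConcatenation, ...) to AddBuiltin instead of the generic ParseOpData, resolving the TODO(b/149408647) items. A hedged usage sketch follows; the header path and the capacity template argument for MicroMutableOpResolver are assumptions and may differ at this revision:

  // Sketch only: registers just the ops a model needs, so only their
  // operator-specific parse functions are referenced.
  #include "tensorflow/lite/micro/micro_mutable_op_resolver.h"  // assumed path

  tflite::MicroMutableOpResolver<4> op_resolver;  // capacity of four ops (assumed template form)
  op_resolver.AddAveragePool2D();   // wired to ParsePool above
  op_resolver.AddMaxPool2D();       // also ParsePool
  op_resolver.AddConv2D();
  op_resolver.AddFullyConnected();
  // Each Add* returns a TfLiteStatus that production code should check.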

View File

@ -100,7 +100,6 @@ const std::map<string, string>& GetKnownBrokenTests() {
{R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"}, {R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"},
{R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"}, {R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"},
{R"(^\/sub.*dtype=tf\.int64)", "119126484"},
{R"(^\/div.*dtype=tf\.int64)", "119126484"}, {R"(^\/div.*dtype=tf\.int64)", "119126484"},
{R"(^\/mul.*dtype=tf\.int64)", "119126484"}, {R"(^\/mul.*dtype=tf\.int64)", "119126484"},
{R"(^\/add.*dtype=tf\.int64)", "119126484"}, {R"(^\/add.*dtype=tf\.int64)", "119126484"},

View File

@ -61,7 +61,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kSub, 1}, "1.6.0"}, {{OperatorType::kSub, 1}, "1.6.0"},
{{OperatorType::kSub, 2}, "1.14.0"}, {{OperatorType::kSub, 2}, "1.14.0"},
{{OperatorType::kSub, 3}, "1.15.0"}, {{OperatorType::kSub, 3}, "1.15.0"},
{{OperatorType::kSub, 4}, "1.15.0"}, {{OperatorType::kSub, 4}, kPendingReleaseOpVersion},
{{OperatorType::kSub, 5}, kPendingReleaseOpVersion}, {{OperatorType::kSub, 5}, kPendingReleaseOpVersion},
{{OperatorType::kDiv, 1}, "1.6.0"}, {{OperatorType::kDiv, 1}, "1.6.0"},
{{OperatorType::kBatchToSpaceND, 1}, "1.6.0"}, {{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},

View File

@ -440,76 +440,6 @@ typedef struct TfLiteTensor {
// `dims_signature` contains [1, -1, -1, 3]). // `dims_signature` contains [1, -1, -1, 3]).
const TfLiteIntArray* dims_signature; const TfLiteIntArray* dims_signature;
} TfLiteTensor; } TfLiteTensor;
#else
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// NOTE: This flag is opt-in only at compile time.
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
#endif // TF_LITE_STATIC_MEMORY
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// A structure representing an instance of a node. // A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not // This structure only exhibits the inputs, outputs and user defined data, not
@ -547,6 +477,112 @@ typedef struct TfLiteNode {
// WARNING: This is an experimental interface that is subject to change. // WARNING: This is an experimental interface that is subject to change.
struct TfLiteDelegate* delegate; struct TfLiteDelegate* delegate;
} TfLiteNode; } TfLiteNode;
#else
// NOTE: This flag is opt-in only at compile time.
//
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// This struct does not use:
// - allocation
// - buffer_handle
// - data_is_stale
// - delegate
// - dims_signature
// - name
// - sparsity
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;
// Quantization information.
TfLiteQuantizationParams params;
// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;
// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;
// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;
// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;
// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;
// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
// Specific reduced TfLiteNode struct for TF Micro runtime. This struct contains
// only the minimum fields required to represent a node.
//
// This struct does not use:
// - delegate
// - intermediates
// - temporaries
typedef struct TfLiteNode {
// Inputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* inputs;
// Outputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* outputs;
// Opaque data provided by the node implementer through `Registration.init`.
void* user_data;
// Opaque data provided to the node if the node is a builtin. This is usually
// a structure defined in builtin_op_data.h
void* builtin_data;
// Custom initial data. This is the opaque data provided in the flatbuffer.
// WARNING: This is an experimental interface that is subject to change.
const void* custom_initial_data;
int custom_initial_data_size;
} TfLiteNode;
#endif // TF_LITE_STATIC_MEMORY
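The comment on the reduced TfLiteTensor above says its fields are ordered largest-to-smallest for optimal struct sizeof; that is about alignment padding rather than anything TFLite-specific. A small generic illustration (not TFLite code; sizes assume a typical 64-bit ABI):

  #include <cstdint>

  struct BadOrder {    // typically 24 bytes: 1 + 7 padding + 8 + 1 + 7 padding
    bool is_variable;
    void* data;
    uint8_t type;
  };

  struct GoodOrder {   // typically 16 bytes: 8 + 1 + 1 + 6 padding
    void* data;
    bool is_variable;
    uint8_t type;
  };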
#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);
// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);
// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY
// WARNING: This is an experimental interface that is subject to change. // WARNING: This is an experimental interface that is subject to change.
// //

View File

@ -466,12 +466,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
op_sig.output_types.at(0) == TensorType_INT16) { op_sig.output_types.at(0) == TensorType_INT16) {
if (op_sig.options.addsub.pot_scale_int16) { if (op_sig.options.addsub.pot_scale_int16) {
return 5; return 5;
} else { }
}
if (!op_sig.input_types.empty() &&
op_sig.input_types.at(0) == TensorType_INT64) {
return 4; return 4;
} }
} if (op_sig.options.broadcast.need_broadcast &&
if (op_sig.options.addsub.need_broadcast && op_sig.options.broadcast.num_dims > 4) {
op_sig.options.addsub.num_dims > 4) {
return 3; return 3;
} }
if (op_sig.input_types.at(0) == TensorType_INT8) { if (op_sig.input_types.at(0) == TensorType_INT8) {

View File

@ -296,6 +296,14 @@ TEST(OpVersionTest, VersioningSubTest) {
SimpleVersioningTest(BuiltinOperator_SUB); SimpleVersioningTest(BuiltinOperator_SUB);
} }
TEST(OpVersionTest, VersioningSub4Test) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_SUB,
.input_types = std::vector<TensorType>{TensorType_INT64},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
}
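A hypothetical companion check, in the same style as the test above, for the broadcast branch of the rewritten SUB version logic; the options.broadcast fields are taken from the op-version hunk earlier in this diff, and this sketch is not part of the commit:

  TEST(OpVersionTest, VersioningSubBroadcastSketch) {
    // Float32 SUB that broadcasts across more than 4 dimensions should
    // resolve to version 3 under the logic shown above.
    OpSignature fake_op_sig = {
        .op = BuiltinOperator_SUB,
        .input_types = std::vector<TensorType>{TensorType_FLOAT32},
    };
    fake_op_sig.options.broadcast.need_broadcast = true;
    fake_op_sig.options.broadcast.num_dims = 5;
    EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 3);
  }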
void SimpleMulVersioningTest(TensorType data_type, float multiplier, void SimpleMulVersioningTest(TensorType data_type, float multiplier,
int version) { int version) {
OpSignature fake_op_sig = { OpSignature fake_op_sig = {

View File

@ -79,6 +79,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_SUB, 1}, "1.6.0"}, {{BuiltinOperator_SUB, 1}, "1.6.0"},
{{BuiltinOperator_SUB, 2}, "1.14.0"}, {{BuiltinOperator_SUB, 2}, "1.14.0"},
{{BuiltinOperator_SUB, 3}, kPendingReleaseVersion}, {{BuiltinOperator_SUB, 3}, kPendingReleaseVersion},
{{BuiltinOperator_SUB, 4}, kPendingReleaseVersion},
{{BuiltinOperator_DENSIFY, 1}, "2.2.0"}, {{BuiltinOperator_DENSIFY, 1}, "2.2.0"},
{{BuiltinOperator_DIV, 1}, "1.6.0"}, {{BuiltinOperator_DIV, 1}, "1.6.0"},
{{BuiltinOperator_DIV, 2}, kPendingReleaseVersion}, {{BuiltinOperator_DIV, 2}, kPendingReleaseVersion},

View File

@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code # This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable # via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date. # TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 6, 23) _FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 6, 24)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS" _FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None _FORWARD_COMPATIBILITY_DATE_NUMBER = None

View File

@ -1116,6 +1116,8 @@ py_library(
"//tensorflow/python:framework_ops", "//tensorflow/python:framework_ops",
"//tensorflow/python:gradients_impl", "//tensorflow/python:gradients_impl",
"//tensorflow/python:init_ops", "//tensorflow/python:init_ops",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python:math_ops",
"//tensorflow/python:training", "//tensorflow/python:training",
"//tensorflow/python:util", "//tensorflow/python:util",
"//tensorflow/python:variable_scope", "//tensorflow/python:variable_scope",
@ -1124,7 +1126,6 @@ py_library(
"//tensorflow/python/eager:backprop", "//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context", "//tensorflow/python/eager:context",
"//tensorflow/python/eager:test", "//tensorflow/python/eager:test",
"//tensorflow/python/keras/layers",
"//third_party/py/numpy", "//third_party/py/numpy",
], ],
) )

View File

@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division from __future__ import division
from __future__ import print_function from __future__ import print_function
import functools
import os import os
import tempfile import tempfile
@ -38,11 +39,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors from tensorflow.python.framework import errors
from tensorflow.python.framework import ops from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util from tensorflow.python.framework import test_util
from tensorflow.python.keras.layers import core
from tensorflow.python.lib.io import tf_record from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradients_impl from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import summary_ops_v2 as summary_ops from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables from tensorflow.python.ops import variables
@ -114,12 +116,19 @@ def _events_from_logdir(test_case, logdir):
class DistributionTestBase(test.TestCase): class DistributionTestBase(test.TestCase):
"""Some tests that should work with any DistributionStrategy.""" """Some tests that should work with any DistributionStrategy."""
def _create_variable_like_keras_dense_layer(self, name, shape, dtype):
initializer = functools.partial(
init_ops_v2.GlorotUniform(), shape, dtype=dtype)
return variables.Variable(
initial_value=initializer, name=name, trainable=True)
def _test_minimize_loss_eager(self, d): def _test_minimize_loss_eager(self, d):
with d.scope(): with d.scope():
l = core.Dense(1, use_bias=False) kernel = self._create_variable_like_keras_dense_layer(
name="kernel", shape=(1, 1), dtype=dtypes.float32)
def loss(x): def loss(x):
y = array_ops.reshape(l(x), []) - array_ops.identity(1.) y = array_ops.reshape(
gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
return y * y return y * y
# TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a # TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
# common `implicit_grad` function and put it in DistributionStrategy. # common `implicit_grad` function and put it in DistributionStrategy.
@ -173,10 +182,12 @@ class DistributionTestBase(test.TestCase):
ops.Graph().as_default(), \ ops.Graph().as_default(), \
self.cached_session(config=config) as sess, \ self.cached_session(config=config) as sess, \
d.scope(): d.scope():
l = core.Dense(1, use_bias=False) kernel = self._create_variable_like_keras_dense_layer(
name="kernel", shape=(1, 1), dtype=dtypes.float32)
def loss(x): def loss(x):
y = array_ops.reshape(l(x), []) - array_ops.identity(1.) y = array_ops.reshape(
gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
return y * y return y * y
grad_fn = backprop.implicit_grad(loss) grad_fn = backprop.implicit_grad(loss)

View File

@ -890,13 +890,13 @@ class Context(object):
@property @property
def executor(self): def executor(self):
ensure_initialized() self.ensure_initialized()
return executor.Executor( return executor.Executor(
pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle)) pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))
@executor.setter @executor.setter
def executor(self, e): def executor(self, e):
ensure_initialized() self.ensure_initialized()
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle()) pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())
@property @property

View File

@ -1192,7 +1192,7 @@ class _TapeGradientFunctions(object):
def _wrap_backward_function(self, forward_graph, backward, outputs): def _wrap_backward_function(self, forward_graph, backward, outputs):
"""Create a backward function given `outputs` from the forward function.""" """Create a backward function given `outputs` from the forward function."""
capture_mapping = dict( capture_mapping = dict(
zip([ops.tensor_id(t) for t in forward_graph.outputs], outputs)) zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
remapped_captures = [ remapped_captures = [
capture_mapping.get(ops.tensor_id(capture), capture) capture_mapping.get(ops.tensor_id(capture), capture)
for capture in backward.captured_inputs for capture in backward.captured_inputs
@ -1489,9 +1489,8 @@ class ConcreteFunction(object):
self._captured_closures = self._func_graph.deferred_external_captures self._captured_closures = self._func_graph.deferred_external_captures
structured_outputs = self._func_graph.structured_outputs structured_outputs = self._func_graph.structured_outputs
self._ndarrays_list = ( self._ndarrays_list = (
isinstance(structured_outputs, (list, tuple)) and isinstance(structured_outputs, (list, tuple)) and structured_outputs and
structured_outputs and all(isinstance(o, np_arrays.ndarray) for o in structured_outputs))
all([isinstance(o, np_arrays.ndarray) for o in structured_outputs]))
self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray) self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)
# function_spec defines the structured signature. # function_spec defines the structured signature.
@ -2199,6 +2198,14 @@ class ConcreteFunction(object):
assert self._function_spec is not None assert self._function_spec is not None
arg_specs, kwarg_specs = self.structured_input_signature arg_specs, kwarg_specs = self.structured_input_signature
arg_names = list(self._function_spec.arg_names) arg_names = list(self._function_spec.arg_names)
# If an explicit input_signature is provided to @tf.function, then any
# arguments with defaults that are not covered by that explicit signature
# are simply dropped from the signature.
# TODO(b/159639913) Look into whether dropping arguments with default values
# from the signature is the right thing to do.
arg_names = arg_names[:len(arg_specs)]
if default_values: if default_values:
for i in range(len(arg_names)): for i in range(len(arg_names)):
if not _contains_type_spec(arg_specs[i]): if not _contains_type_spec(arg_specs[i]):
@ -2248,6 +2255,14 @@ class ConcreteFunction(object):
lines = [self._structured_signature_summary(default_values=True)] lines = [self._structured_signature_summary(default_values=True)]
arg_specs, kwarg_specs = self.structured_input_signature arg_specs, kwarg_specs = self.structured_input_signature
names = list(self._function_spec.arg_names) names = list(self._function_spec.arg_names)
# If an explicit input_signature is provided to @tf.function, then any
# arguments with defaults that are not covered by that explicit signature
# are simply dropped from the signature.
# TODO(b/159639913) Look into whether dropping arguments with default values
# from the signature is the right thing to do.
names = names[:len(arg_specs)]
names.extend(sorted(kwarg_specs)) names.extend(sorted(kwarg_specs))
specs = list(arg_specs) + list(kwarg_specs.values()) specs = list(arg_specs) + list(kwarg_specs.values())
# note: we can skip bound args, since we already displayed their bound # note: we can skip bound args, since we already displayed their bound
@ -2855,7 +2870,6 @@ class Function(object):
graph_function, _, _ = self._maybe_define_function(args, kwargs) graph_function, _, _ = self._maybe_define_function(args, kwargs)
return graph_function return graph_function
# XX TODO: make sure we fix up this path as well!?
def _get_concrete_function_internal(self, *args, **kwargs): def _get_concrete_function_internal(self, *args, **kwargs):
"""Bypasses error checking when getting a graph function.""" """Bypasses error checking when getting a graph function."""
graph_function = self._get_concrete_function_internal_garbage_collected( graph_function = self._get_concrete_function_internal_garbage_collected(

Some files were not shown because too many files have changed in this diff.