commit 1bcc3bc41c
Merge branch 'master' into addsub_16x8
@@ -54,7 +54,7 @@ Status ProcessInputs(
    TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
input_tensors->reserve(ninputs);
for (int i = 0; i < ninputs; ++i) {
Node* node = &inputs[i].oper->node;
Node* node = inputs[i].oper ? &inputs[i].oper->node : nullptr;
int idx = inputs[i].index;

TF_RETURN_WITH_CONTEXT_IF_ERROR(
@@ -90,7 +90,7 @@ Status ProcessOutputs(const TF_Graph* fn_body, const char* fn_name,
    TF_EXCLUSIVE_LOCKS_REQUIRED(fn_body->mu) {
output_tensors->reserve(noutputs);
for (int i = 0; i < noutputs; ++i) {
Node* node = &outputs[i].oper->node;
Node* node = outputs[i].oper ? &outputs[i].oper->node : nullptr;
int idx = outputs[i].index;
TF_RETURN_WITH_CONTEXT_IF_ERROR(
    fn_body->graph.IsValidOutputTensor(node, idx),
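The hunk above (pulled in by the merge) replaces an unconditional dereference of `inputs[i].oper` / `outputs[i].oper` with a null guard. A minimal standalone sketch of that pattern, using simplified stand-in structs rather than the real TensorFlow C API types:

```cpp
#include <cstdio>

// Simplified stand-ins for the types touched in the hunk above; the real
// TF_Output/TF_Operation/Node definitions live in the TensorFlow C API and
// graph headers and carry far more state than this sketch.
struct Node { int id; };
struct TF_Operation { Node node; };
struct TF_Output { TF_Operation* oper; int index; };

// Mirrors the guarded lookup: dereference `oper` only when it is non-null,
// so an empty TF_Output yields nullptr instead of undefined behavior.
Node* NodeFromOutput(const TF_Output& out) {
  return out.oper ? &out.oper->node : nullptr;
}

int main() {
  TF_Operation op{{42}};
  TF_Output valid{&op, 0};
  TF_Output empty{nullptr, 0};
  std::printf("valid -> %d\n", NodeFromOutput(valid)->id);
  std::printf("empty -> %s\n", NodeFromOutput(empty) ? "node" : "nullptr");
  return 0;
}
```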
@@ -19,6 +19,7 @@ tf_cc_shared_object(
cc_library(
    name = "gcs_filesystem_impl",
    srcs = ["gcs_filesystem.cc"],
    hdrs = ["gcs_filesystem.h"],
    copts = select({
        "//conditions:default": [],
        "//tensorflow:windows": get_win_copts(),
@@ -55,14 +56,9 @@ tf_cc_test(
        "notap",
    ],
    deps = [
        ":gcs_helper",
        "//tensorflow/c:env",
        "//tensorflow/c:tf_status",
        ":gcs_filesystem_impl",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/experimental/filesystem:filesystem_interface",
        "//tensorflow/core/platform:stacktrace_handler",
        "//tensorflow/core/platform:test",
        "@com_github_googlecloudplatform_google_cloud_cpp//:storage_client",
        "@com_google_absl//absl/strings",
    ],
)
@@ -12,26 +12,16 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"

#include <stdlib.h>
#include <string.h>

#include <fstream>

#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/env.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_helper.h"
#include "tensorflow/c/tf_status.h"

#ifdef TF_GCS_FILESYSTEM_TEST
// For testing purpose, we expose some functions.
#define TF_STATIC
#else
// Otherwise, we don't expose any symbol.
#define TF_STATIC static
#endif

// Implementation of a filesystem for GCS environments.
// This filesystem will support `gs://` URI schemes.
namespace gcs = google::cloud::storage;
@@ -48,8 +38,8 @@ static inline void TF_SetStatusFromGCSStatus(
static void* plugin_memory_allocate(size_t size) { return calloc(1, size); }
static void plugin_memory_free(void* ptr) { free(ptr); }

static void ParseGCSPath(absl::string_view fname, bool object_empty_ok,
                         char** bucket, char** object, TF_Status* status) {
void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
                  char** object, TF_Status* status) {
  size_t scheme_end = fname.find("://") + 2;
  if (fname.substr(0, scheme_end + 1) != "gs://") {
    TF_SetStatus(status, TF_INVALID_ARGUMENT,
@@ -130,7 +120,7 @@ namespace tf_read_only_memory_region {
namespace tf_gcs_filesystem {

// TODO(vnvo2409): Add lazy-loading and customizing parameters.
TF_STATIC void Init(TF_Filesystem* filesystem, TF_Status* status) {
void Init(TF_Filesystem* filesystem, TF_Status* status) {
  google::cloud::StatusOr<gcs::Client> client =
      gcs::Client::CreateDefaultClient();
  if (!client) {
@@ -143,14 +133,14 @@ TF_STATIC void Init(TF_Filesystem* filesystem, TF_Status* status) {
  TF_SetStatus(status, TF_OK, "");
}

static void Cleanup(TF_Filesystem* filesystem) {
void Cleanup(TF_Filesystem* filesystem) {
  plugin_memory_free(filesystem->plugin_filesystem);
}

// TODO(vnvo2409): Implement later

static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                            TF_WritableFile* file, TF_Status* status) {
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                     TF_WritableFile* file, TF_Status* status) {
  char* bucket;
  char* object;
  ParseGCSPath(path, false, &bucket, &object, status);
@@ -166,8 +156,8 @@ static void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
  TF_SetStatus(status, TF_OK, "");
}

static void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
                              TF_WritableFile* file, TF_Status* status) {
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
                       TF_WritableFile* file, TF_Status* status) {
  char* bucket;
  char* object;
  ParseGCSPath(path, false, &bucket, &object, status);
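The ParseGCSPath hunk above validates the `gs://` scheme before splitting a URI into bucket and object, reporting TF_INVALID_ARGUMENT otherwise. A rough standalone sketch of that splitting contract, using std::string_view and a bool return instead of the plugin's absl::string_view/TF_Status signature (SplitGCSPath is a hypothetical name, not the plugin's function):

```cpp
#include <cstdio>
#include <string>
#include <string_view>

// Splits "gs://bucket/object" into its bucket and object parts. Returns
// false for non-gs:// URIs or, when object_empty_ok is false, for paths
// that carry no object part.
bool SplitGCSPath(std::string_view fname, bool object_empty_ok,
                  std::string* bucket, std::string* object) {
  constexpr std::string_view kScheme = "gs://";
  if (fname.substr(0, kScheme.size()) != kScheme) return false;
  std::string_view rest = fname.substr(kScheme.size());
  size_t slash = rest.find('/');
  if (slash == std::string_view::npos) {
    if (!object_empty_ok) return false;
    *bucket = std::string(rest);
    object->clear();
    return !bucket->empty();
  }
  *bucket = std::string(rest.substr(0, slash));
  *object = std::string(rest.substr(slash + 1));
  return !bucket->empty() && (object_empty_ok || !object->empty());
}

int main() {
  std::string bucket, object;
  if (SplitGCSPath("gs://my-bucket/path/to/file", /*object_empty_ok=*/false,
                   &bucket, &object)) {
    std::printf("bucket=%s object=%s\n", bucket.c_str(), object.c_str());
  }
  return 0;
}
```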
@@ -0,0 +1,35 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
#define TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_

#include "absl/strings/string_view.h"
#include "google/cloud/storage/client.h"
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/tf_status.h"

void ParseGCSPath(absl::string_view fname, bool object_empty_ok, char** bucket,
                  char** object, TF_Status* status);

namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
void Cleanup(TF_Filesystem* filesystem);
void NewWritableFile(const TF_Filesystem* filesystem, const char* path,
                     TF_WritableFile* file, TF_Status* status);
void NewAppendableFile(const TF_Filesystem* filesystem, const char* path,
                       TF_WritableFile* file, TF_Status* status);
}  // namespace tf_gcs_filesystem

#endif  // TENSORFLOW_C_EXPERIMENTAL_FILESYSTEM_PLUGINS_GCS_GCS_FILESYSTEM_H_
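With the declarations now exposed by this new header, a test or host program can drive the plugin functions directly. A hedged usage sketch, assuming the plain TF_Filesystem/TF_WritableFile structs from filesystem_interface.h and using a placeholder gs:// path:

```cpp
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"
#include "tensorflow/c/tf_status.h"

// Exercises the functions exposed by the header: initialize the GCS
// filesystem, open a writable file, then clean up. Error handling is
// reduced to a single status-code check; "gs://bucket/object" is a
// placeholder, not a real bucket.
int main() {
  TF_Status* status = TF_NewStatus();
  TF_Filesystem filesystem;
  tf_gcs_filesystem::Init(&filesystem, status);
  if (TF_GetCode(status) == TF_OK) {
    TF_WritableFile file;
    tf_gcs_filesystem::NewWritableFile(&filesystem, "gs://bucket/object",
                                       &file, status);
    tf_gcs_filesystem::Cleanup(&filesystem);
  }
  TF_DeleteStatus(status);
  return 0;
}
```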
@@ -12,18 +12,14 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/filesystem/filesystem_interface.h"
#include "tensorflow/c/experimental/filesystem/plugins/gcs/gcs_filesystem.h"

#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/stacktrace_handler.h"
#include "tensorflow/core/platform/test.h"

#define ASSERT_TF_OK(x) ASSERT_EQ(TF_OK, TF_GetCode(x))

// Forward declaration
namespace tf_gcs_filesystem {
void Init(TF_Filesystem* filesystem, TF_Status* status);
}

namespace tensorflow {
namespace {

@@ -38,7 +34,7 @@ class GCSFilesystemTest : public ::testing::Test {
  }
  void TearDown() override {
    TF_DeleteStatus(status_);
    // TODO(vnvo2409): Add filesystem cleanup
    tf_gcs_filesystem::Cleanup(filesystem_);
    delete filesystem_;
  }
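The fixture change above moves filesystem teardown into TearDown alongside TF_DeleteStatus. A reduced GoogleTest sketch of that ownership pattern (GCSFilesystemTestSketch is a hypothetical stand-in for the real GCSFilesystemTest; the filesystem cleanup call is shown only as a comment):

```cpp
#include "tensorflow/c/tf_status.h"
#include "tensorflow/core/platform/test.h"

namespace {

// The test owns a TF_Status and, after this change, also owns the filesystem
// object created in SetUp, so TearDown must release both.
class GCSFilesystemTestSketch : public ::testing::Test {
 protected:
  void SetUp() override { status_ = TF_NewStatus(); }
  void TearDown() override {
    TF_DeleteStatus(status_);
    // The real fixture also calls tf_gcs_filesystem::Cleanup(filesystem_)
    // and deletes the filesystem it allocated in SetUp.
  }
  TF_Status* status_;
};

TEST_F(GCSFilesystemTestSketch, StatusStartsOk) {
  EXPECT_EQ(TF_OK, TF_GetCode(status_));
}

}  // namespace
```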
@@ -138,6 +138,11 @@ bool IsI32Type(Type element_type) {
  return element_type.isInteger(32) && !element_type.isUnsignedInteger();
}

// Return true when the given element_type is I64.
bool IsI64Type(Type element_type) {
  return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}

// Return true if the given Add operation has the CPU kernel supported shapes.
bool VerifyAddOpShapeConstraints(AddOp op) {
  auto element_type = getElementTypeOrSelf(op.output().getType());
@@ -174,7 +179,8 @@ bool VerifySubOpShapeConstraints(SubOp op) {
  // Allows F32, QUI8, and QI16 outputs when the operands have valid shapes,
  // which are broadcastable shapes up to five dimension or have same shapes.
  if (element_type.isF32() || IsI32Type(element_type) ||
      IsQUI8Type(element_type) || IsQI16Type(element_type)) {
      IsI64Type(element_type) || IsQUI8Type(element_type) ||
      IsQI16Type(element_type)) {
    return VerifyOperandsHaveSameShapesOrBroadcastableShape(
        /*op=*/op.getOperation(), /*indices=*/ArrayRef<unsigned>{0, 1},
        /*max_bcast_rank=*/5);
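The new IsI64Type predicate accepts signless 64-bit integer element types and rejects unsigned ones, which is what lets VerifySubOpShapeConstraints admit I64 outputs. A small sketch of how the predicate behaves on concrete MLIR types (include paths and builder API follow roughly that era's MLIR layout and may differ across versions):

```cpp
#include "mlir/IR/Builders.h"
#include "mlir/IR/MLIRContext.h"

// Same check as the diff above: 64-bit wide and not explicitly unsigned.
static bool IsI64Type(mlir::Type element_type) {
  return element_type.isInteger(64) && !element_type.isUnsignedInteger();
}

int main() {
  mlir::MLIRContext context;
  mlir::Builder builder(&context);
  bool accepts_i64 = IsI64Type(builder.getIntegerType(64));  // signless i64 -> true
  bool rejects_i32 = IsI64Type(builder.getIntegerType(32));  // i32 -> false
  (void)accepts_i64;
  (void)rejects_i32;
  return 0;
}
```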
@@ -2864,11 +2864,11 @@ def TFL_SubOp : TFL_Op<"sub", [
  }];

  let arguments = (
    ins TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$lhs,
    TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$rhs,
    ins TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$lhs,
    TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$rhs,
    TFL_AFAttr:$fused_activation_function);

  let results = (outs TFL_TensorOf<[F32, I32, QI8, QUI8, QI16]>:$output);
  let results = (outs TFL_TensorOf<[F32, I32, I64, QI8, QUI8, QI16]>:$output);

  let hasFolder = 1;
@@ -9,6 +9,15 @@ func @add(%arg0: tensor<1xf32>, %arg1: tensor<1xf32>) -> tensor<1xf32> {
// CHECK: return
}

func @sub(%arg0: tensor<1xi64>, %arg1: tensor<1xi64>) -> tensor<1xi64> {
  %0 = "tf.Sub"(%arg0, %arg1) : (tensor<1xi64>, tensor<1xi64>) -> tensor<1xi64>
  return %0: tensor<1xi64>

// CHECK-LABEL: sub
// CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "NONE"} : tensor<1xi64>
// CHECK: return
}

// CHECK-LABEL: testAddHighDimsHaveSameShape
func @testAddHighDimsHaveSameShape(%arg0: tensor<1x2x3x4x5x6x7x8xi32>, %arg1: tensor<1x2x3x4x5x6x7x8xi32>) -> tensor<1x2x3x4x5x6x7x8xi32> {
  // CHECK: tfl.add %arg0, %arg1 {fused_activation_function = "NONE"}
@@ -269,6 +269,14 @@ func @testSub(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
  return %0#0 : tensor<? x i32>
}

// CHECK-LABEL: testSubInt64
func @testSubInt64(tensor<? x i64>, tensor<? x i64>) -> tensor<? x i64> {
^bb0(%arg0: tensor<? x i64>, %arg1: tensor<? x i64>):
  // CHECK: tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"}
  %0 = tfl.sub %arg0, %arg1 {fused_activation_function = "RELU6"} : tensor<? x i64>
  return %0#0 : tensor<? x i64>
}

// CHECK-LABEL: testMul
func @testMul(tensor<? x i32>, tensor<? x i32>) -> tensor<? x i32> {
^bb0(%arg0: tensor<? x i32>, %arg1: tensor<? x i32>):
@@ -346,6 +346,7 @@ func @replication(%arg0: tensor<i1>, %arg1: tensor<i32>, %arg2: tensor<f32>) ->
// CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
// CHECK-DAG: [%[[ARG_0]], %[[OP_A]]] as %[[RI_0:[a-z0-9]*]]: tensor<i1>
// CHECK-DAG: [%[[OP_B]], %[[ARG_1]]] as %[[RI_1:[a-z0-9]*]]: tensor<i32>
// CHECK-NOT: _replicated_input_indices
// CHECK-SAME: n = 2 : i32
// CHECK-NEXT: %[[CLUSTER:[0-9]*]]:2 = "tf_device.cluster"() ( {
// CHECK: %[[OP_D:[0-9]*]] = "tf.opD"(%[[RI_0]], %[[RI_1]], %[[ARG_2]], %[[OP_C]])
@@ -382,6 +383,46 @@ func @sort_replicated_input(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<
// CHECK-DAG: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
// CHECK-DAG: [%[[ARG_3]], %[[ARG_3]]] as %{{[a-z0-9]*}}
// CHECK-DAG: [%[[ARG_5]], %[[ARG_5]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _replicated_input_indices = [0, 1, 2, -1, -1, -1]

// Test TPUReplicatedInputs with non contiguous `index` attributes.
// CHECK-LABEL: func @non_contigous_indices
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<i1>, %[[ARG_1:.*]]: tensor<i1>, %[[ARG_2:.*]]: tensor<i1>)
func @non_contigous_indices(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 8 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
  "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
  %1 = "tf.TPUReplicatedInput"(%arg1, %arg1) : (tensor<i1>, tensor<i1>) -> tensor<i1>
  "tf.opB"(%1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
  %2 = "tf.TPUReplicatedInput"(%arg2, %arg2) {index = 2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
  "tf.opC"(%2) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
  return
}

// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_2]], %[[ARG_2]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_0]], %[[ARG_0]]] as %{{[a-z0-9]*}}
// CHECK-SAME: [%[[ARG_1]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _replicated_input_indices = [2, 8, -1]

// Test that the `is_mirrored_variable` attribute is preserved in the
// tf_device.replicate op.
// CHECK-LABEL: func @mirrored_variables
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_2:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_3:.*]]: tensor<!tf.resource<tensor<32xf32>>>)
func @mirrored_variables(%arg0: tensor<!tf.resource<tensor<32xf32>>>, %arg1: tensor<!tf.resource<tensor<32xf32>>>, %arg2: tensor<!tf.resource<tensor<32xf32>>>, %arg3: tensor<!tf.resource<tensor<32xf32>>>) {
  %0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
  %1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
  "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device"} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> ()
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
  return
}

// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _mirrored_variable_indices = [1]
// CHECK-SAME: _replicated_input_indices = [0, 1]

// -----
@@ -407,8 +448,10 @@ func @bad_num_replicas() {
  return
}

// -----

// Test that functions without TPUReplicateMetadata op are skipped without
// error
// CHECK-LABEL: func @missing_metadata_op
@@ -483,22 +526,9 @@ func @leftover_replicated_output(%arg0: tensor<i1>) {
// -----

// Test bad TPUReplicatedInput positive `index` attribute.
func @bad_positive_index_input(%arg0: tensor<i1>) {
  // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got 1}}
  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
  "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
  return
}

// -----

// Test bad TPUReplicatedInput negative `index` attribute.
func @bad_negative_index_input(%arg0: tensor<i1>) {
  // expected-error@+1 {{'tf.TPUReplicatedInput' index is not in range [-1, 1), got -2}}
  // expected-error@+1 {{'tf.TPUReplicatedInput' op requires index to be at least -1, but got -2}}
  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = -2 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
  "tf.opA"(%0) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>) -> ()
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
@@ -509,33 +539,12 @@ func @bad_negative_index_input(%arg0: tensor<i1>) {
// -----

// Test TPUReplicatedInput with conflicting `index` attribute. This will result
// in gaps in the TPUReplicatedInput ordering.
// Test TPUReplicatedInput with conflicting `index` attribute.
func @input_index_gaps(%arg0: tensor<i1>) {
  // expected-error@+1 {{failed to sort 'tf.TPUReplicatedInput' ops, gap(s) found in indices}}
  %0 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
  // expected-error@+1 {{'tf.TPUReplicatedInput' op requires indices to be unique, but found multiple 'tf.TPUReplicatedInput' ops with index 1}}
  %1 = "tf.TPUReplicatedInput"(%arg0, %arg0) {index = 1 : i64} : (tensor<i1>, tensor<i1>) -> tensor<i1>
  "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device", name = "name"} : (tensor<i1>, tensor<i1>) -> ()
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
  return
}

// -----

// Test that the `is_mirrored_variable` attribute is preserved in the
// tf_device.replicate op.
// CHECK-LABEL: func @mirrored_variables
// CHECK-SAME: (%[[ARG_0:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_2:.*]]: tensor<!tf.resource<tensor<32xf32>>>, %[[ARG_3:.*]]: tensor<!tf.resource<tensor<32xf32>>>)
func @mirrored_variables(%arg0: tensor<!tf.resource<tensor<32xf32>>>, %arg1: tensor<!tf.resource<tensor<32xf32>>>, %arg2: tensor<!tf.resource<tensor<32xf32>>>, %arg3: tensor<!tf.resource<tensor<32xf32>>>) {
  %0 = "tf.TPUReplicatedInput"(%arg0, %arg1) {index = 0 : i64} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
  %1 = "tf.TPUReplicatedInput"(%arg2, %arg3) {index = 1 : i64, is_mirrored_variable = true} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> tensor<!tf.resource<tensor<32xf32>>>
  "tf.opA"(%0, %1) {_tpu_replicate = "replicate", device = "device"} : (tensor<!tf.resource<tensor<32xf32>>>, tensor<!tf.resource<tensor<32xf32>>>) -> ()
  "tf.TPUReplicateMetadata"() {_tpu_replicate = "replicate", device = "device", num_replicas = 2, topology = "topology"} : () -> ()
  return
}

// CHECK: tf_device.replicate
// CHECK-SAME: [%[[ARG_0]], %[[ARG_1]]] as %{{[a-z0-9]*}}
// CHECK-SAME: _mirrored_variable_indices = [1]
@@ -9,7 +9,7 @@
// padding_arg_index: 1
// CHECK-LABEL: func @single_arg_single_shape
func @single_arg_single_shape(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func0, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -36,7 +36,7 @@ func @func0(%arg0: tensor<i1>, %arg1: tensor<i1>) {
// padding_arg_index: 2
// CHECK-LABEL: func @single_arg_multiple_shapes
func @single_arg_multiple_shapes(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>) {_replicated_input_indices = [0, 1, 2], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2) {func = @func1, padding_map = ["\10\02\18\01", "\10\03\18\02"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -68,7 +68,7 @@ func @func1(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// padding_arg_index: 3
// CHECK-LABEL: func @multiple_args
func @multiple_args(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>, [%arg0, %arg0] as %ri_4: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>, [%arg0, %arg0] as %ri_4: tensor<i1>) {_replicated_input_indices = [0, 1, 2, 3, 4], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_0, %ri_1, %ri_2, %ri_3, %ri_4) {func = @func2, padding_map = ["\10\02\18\01", "\10\03\18\02", "\08\04\10\01\18\03"]} : (tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -89,7 +89,7 @@ func @func2(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>, %arg3: tens
// padding_arg_index: 1
// CHECK-LABEL: func @remap_indices
func @remap_indices(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func3, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -124,7 +124,7 @@ func @func4(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// Test encapsulated function is not modified when there are no padding maps.
// CHECK-LABEL: func @no_padding_map
func @no_padding_map(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func5} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -140,7 +140,7 @@ func @func5(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// Test encapsulated function is not modified when padding maps is empty.
// CHECK-LABEL: func @empty_padding_map
func @empty_padding_map(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_1, %arg0, %ri_0) {func = @func6, padding_map = []} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -161,7 +161,7 @@ func @func6(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
// padding_arg_index: 1
// CHECK-LABEL: func @unused_padding_map
func @unused_padding_map(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_1) {func = @func7, padding_map = ["\10\02\18\01"]} : (tensor<i1>) -> ()
    tf_device.return
  }
@@ -187,7 +187,7 @@ func @func7(%arg0: tensor<i1>) {
// shape_index: 2
// padding_arg_index: 3
func @missing_padding_arg(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>) {n = 2 : i32} {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>, [%arg0, %arg0] as %ri_2: tensor<i1>, [%arg0, %arg0] as %ri_3: tensor<i1>) {_replicated_input_indices = [0, 1, 2, 3], n = 2 : i32} {
    // expected-warning@+1 {{bad 'padding_map' attribute at index 0, unused padding_arg_index 1}}
    "tf_device.cluster_func"(%ri_0, %ri_2, %ri_3) {func = @func8, padding_map = ["\10\02\18\01", "\08\02\10\02\18\03"]} : (tensor<i1>, tensor<i1>, tensor<i1>) -> ()
    tf_device.return
@@ -201,11 +201,55 @@ func @func8(%arg0: tensor<i1>, %arg1: tensor<i1>, %arg2: tensor<i1>) {
  return
}

// Test tf_device.replicate with missing _replicated_input_indices does no
// transformation.
//
// Padding map "\10\02\18\01":
//   arg_index: 0
//   shape_index: 2
//   padding_arg_index: 1
// CHECK-LABEL: func @missing_replicated_input_indices
func @missing_replicated_input_indices(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func9, padding_map = ["\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
  return
}

// CHECK-LABEL: func @func9
// CHECK-NOT: xla_hlo.padding_map
func @func9(%arg0: tensor<i1>, %arg1: tensor<i1>) {
  return
}

// Test single argument with padding map lifted to associated encapsulated
// function.
//
// Padding map "\08\08\10\06\18\02"
//   arg_index: 8
//   shape_index: 6
//   padding_arg_index: 2
// CHECK-LABEL: func @non_contigous_indices
func @non_contigous_indices(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [2, 8], n = 2 : i32} {
    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @func10, padding_map = ["\08\08\10\06\18\02"]} : (tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
  return
}

// CHECK-LABEL: func @func10
// CHECK-SAME: (%{{[a-z0-9]+}}: tensor<i1>, %{{[a-z0-9]+}}: tensor<i1> {xla_hlo.padding_map = {padding_arg_indices = [0 : i32], shape_indices = [6 : i32]}})
func @func10(%arg0: tensor<i1>, %arg1: tensor<i1>) {
  return
}

// -----

// Test bad padding map attribute (not an array).
func @bad_padding_map() {
  tf_device.replicate {n = 2 : i32} {
  tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op requires 'padding_map' array attribute}}
    "tf_device.cluster_func"() {func = @_func, padding_map = 0 : i32} : () -> ()
    tf_device.return
@@ -221,7 +265,7 @@ func @_func() {

// Test bad padding map attribute (element in array is not a string).
func @bad_padding_map_element() {
  tf_device.replicate {n = 2 : i32} {
  tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, not a string}}
    "tf_device.cluster_func"() {func = @_func, padding_map = [0 : i32]} : () -> ()
    tf_device.return
@@ -237,7 +281,7 @@ func @_func() {

// Test unparsable padding map.
func @bad_padding_map_proto() {
  tf_device.replicate {n = 2 : i32} {
  tf_device.replicate {_replicated_input_indices = [], n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, failed to parse 'z' as tensorflow::tpu::PaddingMap}}
    "tf_device.cluster_func"() {func = @_func, padding_map = ["z"]} : () -> ()
    tf_device.return
@@ -258,8 +302,8 @@ func @_func() {
//   shape_index: 2
//   padding_arg_index: 1
func @negative_arg_index(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got -1}}
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be nonnegative, but got -1}}
    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\FF\FF\FF\FF\FF\FF\FF\FF\FF\01\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -272,27 +316,6 @@ func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {

// -----

// Test out of bound arg index.
//
// Padding map "\08\02\10\02\18\01":
//   arg_index: 2
//   shape_index: 2
//   padding_arg_index: 1
func @bad_arg_index(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, arg_index must be in [0, 2), got 2}}
    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\02\10\02\18\01"]} : (tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
  return
}

func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
  return
}

// -----

// Test negative padding arg index.
//
// Padding map "\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01":
@@ -300,8 +323,8 @@ func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
//   shape_index: 2
//   padding_arg_index: -1
func @negative_padding_arg_index(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got -1}}
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {_replicated_input_indices = [0, 1], n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be nonnegative, but got -1}}
    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\FF\FF\FF\FF\FF\FF\FF\FF\FF\01"]} : (tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
@@ -311,24 +334,3 @@ func @negative_padding_arg_index(%arg0: tensor<i1>) {
func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
  return
}

// -----

// Test out of bound padding arg index.
//
// Padding map "\08\01\10\02\18\02":
//   arg_index: 1
//   shape_index: 2
//   padding_arg_index: 2
func @bad_padding_arg_index(%arg0: tensor<i1>) {
  tf_device.replicate([%arg0, %arg0] as %ri_0: tensor<i1>, [%arg0, %arg0] as %ri_1: tensor<i1>) {n = 2 : i32} {
    // expected-error@+1 {{'tf_device.cluster_func' op bad 'padding_map' attribute at index 0, padding_arg_index must be in [0, 2), got 2}}
    "tf_device.cluster_func"(%ri_0, %ri_1) {func = @_func, padding_map = ["\08\01\10\02\18\02"]} : (tensor<i1>, tensor<i1>) -> ()
    tf_device.return
  }
  return
}

func @_func(%arg0: tensor<i1>, %arg1: tensor<i1>) {
  return
}
@@ -2,460 +2,455 @@

// Tests that missing `_xla_outside_compilation` attribute value results in an error.

func @missing_outside_compilation_attribute() -> () {
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    // expected-error@+1 {{attribute '_xla_outside_compilation' is empty}}
    "tf.B"() {_xla_outside_compilation = ""} : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}
module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:worker/replica:0/task:0/device:CPU:0", "/job:worker/replica:0/task:0/device:TPU_SYSTEM:0", "/job:worker/replica:0/task:0/device:TPU:0"]} {
// Tests that TPU cluster with no outside compilation does not generate parallel_execute.

// -----
// CHECK-LABEL: func @no_outside_compilation
func @no_outside_compilation() -> tensor<?xi32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.A"() : () -> tensor<?xi32>
    %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
  return %0 : tensor<?xi32>
}

// Tests that TPU cluster with no outside compilation does not generate parallel_execute.
// CHECK-NOT: "tf_device.parallel_execute"

// CHECK-LABEL: func @no_outside_compilation
func @no_outside_compilation() -> tensor<?xi32> {
  %0 = "tf_device.cluster"() ( {
    %1 = "tf.A"() : () -> tensor<?xi32>
    %2 = "tf.B"(%1) : (tensor<?xi32>) -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
  return %0 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with no input or output dependecies.

// CHECK-NOT: "tf_device.parallel_execute"

// Tests extraction of a single outside compiled cluster with no input or output dependecies.

// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK: cluster_attr = "cluster_attr"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependecies.

// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NEXT: "tf.C"
  // CHECK-NEXT: "tf.D"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK-NEXT: "tf.E"
  // CHECK: cluster_attr = "cluster_attr"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// Tests extraction of a multiple outside compiled clusters with no input or output dependecies.

// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-COUNT-2: "tf_device.launch"
  // CHECK: "tf_device.cluster"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {cluster_attr = "cluster_attr"} : () -> ()
  return
}

// Tests extraction of a single outside compiled cluster with single TPU cluster return.

// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
// CHECK-LABEL: func @nodep_single_outside_compilation
func @nodep_single_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK-NEXT: tf_device.return
  // CHECK-NEXT: device = "/job:worker/replica:0/task:0/device:CPU:0"
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %3 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %3 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
    "tf.C"() : () -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  return
}

  return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with multiple ops and no input or output dependecies.

// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.

// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2, %3 = "tf_device.cluster"() ( {
      %4 = "tf.A"() : () -> tensor<?xf32>
// CHECK-LABEL: func @nodep_single_cluster_multiple_ops_outside_compilation
func @nodep_single_cluster_multiple_ops_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NEXT: "tf.C"
  // CHECK-NEXT: "tf.D"
  // CHECK-NOT: _xla_outside_compilation
  // CHECK: "tf_device.cluster"
  // CHECK-NEXT: "tf.A"
  // CHECK-NEXT: "tf.E"
  // CHECK: device_assignment = [], num_cores_per_replica = 1 : i64, topology = ""
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %5 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> (tensor<?xf32>, tensor<?xi32>)
    tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
    "tf.C"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  return
}

  return %1 : tensor<?xf32>
}
// Tests extraction of a multiple outside compiled clusters with no input or output dependecies.

// Tests extraction of a single outside compiled cluster with single device->host input.

// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
// CHECK-LABEL: func @nodep_multiple_outside_compilation
func @nodep_multiple_outside_compilation() -> () {
  // CHECK: "tf_device.parallel_execute"
  // CHECK-COUNT-2: "tf_device.launch"
  // CHECK: "tf_device.cluster"
  "tf_device.cluster"() ( {
    "tf.A"() : () -> ()
    "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
    "tf.C"() : () -> ()
    "tf.D"() {_xla_outside_compilation = "cluster2"} : () -> ()
    "tf.E"() : () -> ()
    tf_device.return
  }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> ()
  return
}

  return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with single TPU cluster return.

// Tests extraction of a single outside compiled cluster with single host->device output.
// CHECK-LABEL: func @single_tpu_return_single_outside_compilation
func @single_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK-NEXT: "tf.B"
  // CHECK-NEXT: tf_device.return
  // CHECK-NEXT: device = "TPU_REPLICATED_HOST"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]] = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      "tf.A"() : () -> ()
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %3 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %3 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }

// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation
func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"()
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  return %1 : tensor<?xi32>
}

  return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with multiple TPU cluster return.

// Tests extraction of a single outside compiled cluster host output returned by TPU cluster.
// CHECK-LABEL: func @multiple_tpu_return_single_outside_compilation
func @multiple_tpu_return_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xf32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:4 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[TPU_CLUSTER_OUTPUT:[0-9]*]]:2 = "tf_device.cluster"
  // CHECK: tf_device.return
  // CHECK: tf_device.return %[[TPU_CLUSTER_OUTPUT]]
  // CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]
  %1:4 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2, %3 = "tf_device.cluster"() ( {
      %4 = "tf.A"() : () -> tensor<?xf32>
      "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> ()
      %5 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4, %5 : tensor<?xf32>, tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xf32>, tensor<?xi32>)
    tf_device.return %2, %3 : tensor<?xf32>, tensor<?xi32>
  }

// CHECK-LABEL: func @return_host_output_outside_compilation
func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: tf_device.return %[[HOST_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
      tf_device.return %4 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  return %1 : tensor<?xf32>
}

  return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with single device->host input.

// Tests extraction of a single outside compiled cluster with single input/output.
// CHECK-LABEL: func @single_outside_compiled_input_single_outside_compilation
func @single_outside_compiled_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
      %4 = "tf.C"() : () -> tensor<?xi32>
      tf_device.return %4 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }

// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation
func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  return %1 : tensor<?xi32>
}

  return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster with single host->device output.

// Tests extraction of a single outside compiled cluster with multiple input/output.
// CHECK-LABEL: func @single_outside_compiled_output_single_outside_compilation
func @single_outside_compiled_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"()
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"()
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.C"(%[[HOST_OUTPUT]])
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() {_xla_outside_compilation = "cluster1"} : () -> (tensor<?xi32>)
      %5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %5 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }

// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
  // CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"() : () -> (tensor<?xi32>)
      %5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
      %7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
      %8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
      tf_device.return %8 : tensor<?xi32>
    }) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  return %1 : tensor<?xi32>
}

  return %1 : tensor<?xi32>
}
// Tests extraction of a single outside compiled cluster host output returned by TPU cluster.

// Tests extraction of a multiple outside compiled clusters with input/output.
// CHECK-LABEL: func @return_host_output_outside_compilation
func @return_host_output_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: tf_device.return %[[HOST_OUTPUT]]
  %1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
    %2 = "tf_device.cluster"() ( {
      %3 = "tf.A"() : () -> (tensor<?xi32>)
      %4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
      %5 = "tf.C"(%3) : (tensor<?xi32>) -> (tensor<?xi32>)
      tf_device.return %4 : tensor<?xi32>
    }) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
    tf_device.return %2 : tensor<?xi32>
  }

// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
  %0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
  // CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
  // CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
  // CHECK-NEXT: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]])
  // CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
  // CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster2"
  // CHECK: "tf_device.launch"
  // CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
  // CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]])
  // CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
  // CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
  // CHECK: "tf_device.cluster"
  // CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
  // CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
  // CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
|
||||
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
// CHECK: "tf.E"(%[[HOST_OUTPUT2]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%5 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %7 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
// Tests extraction of a single outside compiled cluster with single input/output.
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with arg input and single device->host input.
|
||||
// CHECK-LABEL: func @single_outside_compiled_input_output_single_outside_compilation
|
||||
func @single_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT]])
|
||||
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[HOST_OUTPUT:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.C"(%[[HOST_OUTPUT]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%5 = "tf.C"(%4) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %5 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @mixed_input_single_outside_compilation
|
||||
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
"tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
|
||||
%4 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
// Tests extraction of a single outside compiled cluster with multiple input/output.
|
||||
|
||||
// Tests extraction of a multiple outside compiled clusters with single device->host input.
|
||||
// CHECK-LABEL: func @multiple_outside_compiled_input_output_single_outside_compilation
|
||||
func @multiple_outside_compiled_input_output_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]]:2 = "tf.C"(%[[RECV_OUTPUT]]#0, %[[RECV_OUTPUT]]#1)
|
||||
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]]#0, %[[B_OUTPUT]]#1, %[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: %[[HOST_OUTPUT:[0-9]*]]:2 = "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.D"(%[[HOST_OUTPUT]]#0)
|
||||
// CHECK: "tf.E"(%[[HOST_OUTPUT]]#1)
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"() : () -> (tensor<?xi32>)
|
||||
%5, %6 = "tf.C"(%3, %4) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
|
||||
%7 = "tf.D"(%5) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
%8 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %8 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
|
||||
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
%4 = "tf.C"() : () -> tensor<?xi32>
|
||||
"tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
// Tests extraction of a multiple outside compiled clusters with input/output.
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with multiple device->host inputs.
|
||||
// CHECK-LABEL: func @outside_compiled_input_output_multiple_outside_compilation
|
||||
func @outside_compiled_input_output_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT2]])
|
||||
// CHECK: %[[D_OUTPUT:[0-9]*]] = "tf.D"(%[[RECV_OUTPUT2]])
|
||||
// CHECK: "tf._XlaSendFromHost"(%[[D_OUTPUT]], %[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT1]])
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[RECV_OUTPUT1]])
|
||||
// CHECK: "tf._XlaSendFromHost"(%[[B_OUTPUT]], %[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[HOST_OUTPUT1:[0-9]*]] = "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"(%[[HOST_OUTPUT1]])
|
||||
// CHECK: %[[HOST_OUTPUT2:[0-9]*]] = "tf._HostComputeMlir"(%[[C_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
// CHECK: "tf.E"(%[[HOST_OUTPUT2]])
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%5 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%6 = "tf.D"(%5) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%7 = "tf.E"(%6) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
tf_device.return %7 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
|
||||
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"() : () -> (tensor<?xi32>)
|
||||
"tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
"tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
|
||||
%5 = "tf.E"() : () -> tensor<?xi32>
|
||||
tf_device.return %5 : tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
// Tests extraction of a single outside compiled cluster with arg input and single device->host input.
|
||||
|
||||
// Tests only directly used results of tpu cluster are remapped with
|
||||
// parallel_execute.
|
||||
// CHECK-LABEL: func @mixed_input_single_outside_compilation
|
||||
func @mixed_input_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.B"(%arg0, %[[RECV_OUTPUT]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
"tf.B"(%arg0, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
|
||||
%4 = "tf.C"() : () -> tensor<?xi32>
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// CHECK-LABEL: func @remapped_results
|
||||
func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
|
||||
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2:2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
|
||||
tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
|
||||
}) {cluster_attr = "cluster_attr"} : () -> (tensor<?xi32>, tensor<?xi32>)
|
||||
tf_device.return %2#1 : tensor<?xi32>
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests extraction of a multiple outside compiled clusters with single device->host input.
|
||||
|
||||
// CHECK-LABEL: func @single_outside_compiled_input_multiple_outside_compilation
|
||||
func @single_outside_compiled_input_multiple_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT_2:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_2:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT_2:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_2]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT_2]])
|
||||
// CHECK: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT_1:[a-z_0-9]*]], %[[PROGRAM_OUTPUT_1:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT_1:[0-9]*]] = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT_1]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.B"(%[[RECV_OUTPUT_1]])
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: %[[C_OUTPUT:[0-9]*]] = "tf.C"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[C_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster2"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
"tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
%4 = "tf.C"() : () -> tensor<?xi32>
|
||||
"tf.D"(%4) {_xla_outside_compilation = "cluster2"} : (tensor<?xi32>) -> ()
|
||||
tf_device.return %4 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests extraction of a single outside compiled cluster with multiple device->host inputs.
|
||||
|
||||
// CHECK-LABEL: func @multiple_outside_compiled_inputs_single_outside_compilation
|
||||
func @multiple_outside_compiled_inputs_single_outside_compilation(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]] = "tf_device.parallel_execute"
|
||||
// CHECK-NEXT: "tf_device.launch"
|
||||
// CHECK: %[[STATUS_OUTPUT:[a-z_0-9]*]], %[[PROGRAM_OUTPUT:[a-z_0-9]*]] = "tf._TPUCompileMlir"
|
||||
// CHECK: %[[RECV_OUTPUT:[0-9]*]]:2 = "tf._XlaRecvAtHost"(%[[PROGRAM_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
// CHECK: "tf.C"(%[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf.D"(%[[RECV_OUTPUT]]#1, %[[RECV_OUTPUT]]#0)
|
||||
// CHECK: "tf_device.cluster"
|
||||
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"
|
||||
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"
|
||||
// CHECK: "tf._HostComputeMlir"(%[[A_OUTPUT]], %[[B_OUTPUT]])
|
||||
// CHECK-SAME: key = "host_compute_channel_cluster1"
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"() : () -> (tensor<?xi32>)
|
||||
"tf.C"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> ()
|
||||
"tf.D"(%4, %3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>, tensor<?xi32>) -> ()
|
||||
%5 = "tf.E"() : () -> tensor<?xi32>
|
||||
tf_device.return %5 : tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<?xi32>
|
||||
tf_device.return %2 : tensor<?xi32>
|
||||
}
|
||||
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
||||
// Tests only directly used results of tpu cluster are remapped with
|
||||
// parallel_execute.
|
||||
|
||||
// CHECK-LABEL: func @remapped_results
|
||||
func @remapped_results(%arg0: tensor<?xi32>) -> tensor<?xi32> {
|
||||
%0 = "tf.A"(%arg0) : (tensor<?xi32>) -> tensor<?xi32>
|
||||
// CHECK: %[[REPLICATE:[0-9]*]]:2 = tf_device.replicate
|
||||
// CHECK: %[[PARALLEL_EXECUTE_OUTPUT:[0-9]*]]:2 = "tf_device.parallel_execute"
|
||||
// CHECK: tf_device.return %[[PARALLEL_EXECUTE_OUTPUT]]#1 : tensor<?xi32>
|
||||
%1:2 = tf_device.replicate([%0, %arg0] as %ri_0: tensor<?xi32>) {n = 2 : i32} {
|
||||
%2:2 = "tf_device.cluster"() ( {
|
||||
%3 = "tf.A"() : () -> (tensor<?xi32>)
|
||||
%4 = "tf.B"(%3) {_xla_outside_compilation = "cluster1"} : (tensor<?xi32>) -> (tensor<?xi32>)
|
||||
%5:2 = "tf.C"(%4) : (tensor<?xi32>) -> (tensor<?xi32>, tensor<?xi32>)
|
||||
tf_device.return %5#0, %5#1 : tensor<?xi32>, tensor<?xi32>
|
||||
}) {num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> (tensor<?xi32>, tensor<?xi32>)
|
||||
tf_device.return %2#1 : tensor<?xi32>
|
||||
}
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
return %1 : tensor<?xi32>
|
||||
}
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
// not have ops outside of the cluster that are both operands and results of the
|
||||
// cluster. Note, this currently does not handle side effecting ops yet.
|
||||
|
||||
#include <algorithm>
|
||||
#include <iterator>
|
||||
#include <memory>
|
||||
#include <tuple>
|
||||
@ -29,6 +30,7 @@ limitations under the License.
|
||||
|
||||
#include "llvm/ADT/ArrayRef.h"
|
||||
#include "llvm/ADT/DenseMap.h"
|
||||
#include "llvm/ADT/DenseSet.h"
|
||||
#include "llvm/ADT/STLExtras.h"
|
||||
#include "llvm/ADT/SetVector.h"
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
@ -59,6 +61,7 @@ constexpr char kTPUReplicateAttr[] = "_tpu_replicate";
|
||||
constexpr char kDeviceAttr[] = "device";
|
||||
constexpr char kNameAttr[] = "name";
|
||||
constexpr char kNumReplicasAttr[] = "num_replicas";
|
||||
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
|
||||
constexpr char kMirroredVariableIndicesAttr[] = "_mirrored_variable_indices";
|
||||
|
||||
constexpr char kBadTPUReplicateAttrMsg[] =
|
||||
@ -261,33 +264,42 @@ void MovePrecedingClusterUsers(tf_device::ClusterOp cluster,

// Sorts `tf.TPUReplicatedInput` ops by `index` attribute. Ops with an `index`
// of -1 are always after ops with a non negative `index`, and an arbitrary
// ordering is used as there are no dependencies on their relative ordering.
// ordering is used as there are no dependencies on their relative ordering. If
// there are multiple `tf.TPUReplicatedInput` ops with the same non negative
// index or if indices are less than -1, an error will be returned.
LogicalResult SortTPUReplicatedInputsByIndex(
llvm::ArrayRef<Operation*> inputs,
llvm::SmallVectorImpl<Operation*>* sorted_inputs) {
const int input_size = inputs.size();
sorted_inputs->resize(input_size, nullptr);
int last_index = input_size - 1;

llvm::SmallDenseSet<int64_t, 8> unique_indices;
for (Operation* input : inputs) {
int64_t index =
llvm::cast<TF::TPUReplicatedInputOp>(input).index().getLimitedValue();

if (index >= input_size || index < -1)
return input->emitError() << "'" << input->getName().getStringRef()
<< "' index is not in range [-1, " << input_size
<< "), got " << index;

if (index == -1)
(*sorted_inputs)[last_index--] = input;
else
(*sorted_inputs)[index] = input;
llvm::cast<TF::TPUReplicatedInputOp>(input).index().getSExtValue();
if (index < -1)
return input->emitOpError()
<< "requires index to be at least -1, but got " << index;
if (index == -1) continue;
if (!unique_indices.insert(index).second)
return input->emitOpError()
<< "requires indices to be unique, but found multiple '"
<< input->getName() << "' ops with index " << index;
}

if (llvm::any_of(*sorted_inputs, [](Operation* op) { return op == nullptr; }))
return inputs.front()->emitError()
<< "failed to sort '" << inputs.front()->getName().getStringRef()
<< "' ops, gap(s) found in indices";
// Sort all TPUReplicatedInputs by `index` attribute to have
// TPUReplicatedInputs with indices be added to the `tf_device.replicate` op
// deterministically. If `index` attribute is -1, instead move them to the
// end.
sorted_inputs->assign(inputs.begin(), inputs.end());
std::stable_sort(
sorted_inputs->begin(), sorted_inputs->end(),
[](Operation* l, Operation* r) {
int64_t l_index =
llvm::cast<TF::TPUReplicatedInputOp>(l).index().getSExtValue();
int64_t r_index =
llvm::cast<TF::TPUReplicatedInputOp>(r).index().getSExtValue();
if (l_index == -1 && r_index != -1) return false;
if (r_index == -1 && l_index != -1) return true;
return l_index < r_index;
});

return success();
}
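
The replacement above drops the bucket-placement scheme (which rejected indices >= input_size and gaps) in favor of a stable sort that only requires indices to be at least -1 and unique, with -1 entries ordered last. A minimal Go sketch of the same comparator semantics, using a hypothetical `replicatedInput` type purely for illustration:

```go
package main

import (
	"fmt"
	"sort"
)

// replicatedInput is a stand-in for a TPUReplicatedInput op carrying an index.
type replicatedInput struct {
	name  string
	index int64
}

func main() {
	ops := []replicatedInput{{"a", -1}, {"b", 2}, {"c", 0}, {"d", -1}, {"e", 1}}
	// Non-negative indices ascend, every -1 sorts after them, and the stable
	// sort keeps the original relative order of the -1 entries.
	sort.SliceStable(ops, func(i, j int) bool {
		l, r := ops[i].index, ops[j].index
		if l == -1 && r != -1 {
			return false
		}
		if r == -1 && l != -1 {
			return true
		}
		return l < r
	})
	fmt.Println(ops) // [{c 0} {e 1} {b 2} {a -1} {d -1}]
}
```
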
@ -315,6 +327,11 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
||||
unique_replicated_input_ops.getArrayRef(), &replicated_input_ops)))
|
||||
return failure();
|
||||
|
||||
// Index attribute value stored on TPUReplicatedInput op. These will be used
|
||||
// later for dynamic padder.
|
||||
llvm::SmallVector<int64_t, 8> replicated_input_indices;
|
||||
bool has_replicated_input_index = false;
|
||||
|
||||
// Indices of the replicate op's arguments that are mirrored variables.
|
||||
llvm::SmallVector<int64_t, 8> mirrored_variable_indices;
|
||||
|
||||
@ -330,7 +347,14 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
||||
|
||||
replicated_inputs.push_back(
|
||||
{input->getOperands(), input->getOperand(0).getType()});
|
||||
if (llvm::cast<TF::TPUReplicatedInputOp>(input).is_mirrored_variable())
|
||||
|
||||
auto tpu_replicated_input = llvm::cast<TF::TPUReplicatedInputOp>(input);
|
||||
int64_t tpu_replicated_input_index =
|
||||
tpu_replicated_input.index().getSExtValue();
|
||||
replicated_input_indices.push_back(tpu_replicated_input_index);
|
||||
if (tpu_replicated_input_index != -1) has_replicated_input_index = true;
|
||||
|
||||
if (tpu_replicated_input.is_mirrored_variable())
|
||||
mirrored_variable_indices.push_back(pos_and_input.index());
|
||||
}
|
||||
|
||||
@ -340,6 +364,10 @@ LogicalResult ReplicateCluster(tf_device::ClusterOp cluster, int num_replicas) {
|
||||
cluster.getLoc(), num_replicas,
|
||||
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<StringRef, 4>>(),
|
||||
replicated_inputs, cluster.getResultTypes());
|
||||
if (has_replicated_input_index)
|
||||
replicate_op.setAttr(kReplicatedInputIndicesAttr,
|
||||
builder.getI64ArrayAttr(replicated_input_indices));
|
||||
|
||||
if (!mirrored_variable_indices.empty())
|
||||
replicate_op.setAttr(kMirroredVariableIndicesAttr,
|
||||
builder.getI64ArrayAttr(mirrored_variable_indices));
|
||||
|
@ -40,6 +40,7 @@ limitations under the License.
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
|
||||
constexpr char kReplicatedInputIndicesAttr[] = "_replicated_input_indices";
|
||||
constexpr char kPaddingMapAttr[] = "padding_map";
|
||||
|
||||
// This pass remaps and assigns padding maps to an encapsulated function's
|
||||
@ -56,14 +57,23 @@ struct TPUDynamicPaddingMapper
|
||||
// Creates a mapping from replicated input index (in `tf_device.replicate` op)
|
||||
// to `tf_device.cluster_func` operand index.
|
||||
llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
|
||||
tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate) {
|
||||
tf_device::ClusterFuncOp cluster_func, tf_device::ReplicateOp replicate,
|
||||
ArrayAttr replicated_input_indices_attr) {
|
||||
Block* replicate_block = &replicate.GetBody();
|
||||
|
||||
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices;
|
||||
for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands()))
|
||||
if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>())
|
||||
if (block_arg.getOwner() == replicate_block)
|
||||
remapped_indices[block_arg.getArgNumber()] = operand_and_idx.index();
|
||||
for (auto operand_and_idx : llvm::enumerate(cluster_func.getOperands())) {
|
||||
if (auto block_arg = operand_and_idx.value().dyn_cast<BlockArgument>()) {
|
||||
if (block_arg.getOwner() == replicate_block) {
|
||||
int64_t replicated_input_index =
|
||||
replicated_input_indices_attr[block_arg.getArgNumber()]
|
||||
.cast<IntegerAttr>()
|
||||
.getInt();
|
||||
if (replicated_input_index != -1)
|
||||
remapped_indices[replicated_input_index] = operand_and_idx.index();
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return remapped_indices;
|
||||
}
|
||||
@ -73,16 +83,15 @@ llvm::SmallDenseMap<int32_t, int32_t> GetRemappedReplicatedInputIndices(
|
||||
// indices. An error will be returned if an index is not found or parsing
|
||||
// failed.
|
||||
LogicalResult GetRemappedPaddings(
|
||||
tf_device::ClusterFuncOp cluster_func, int num_replicated_args,
|
||||
tf_device::ClusterFuncOp cluster_func,
|
||||
const llvm::SmallDenseMap<int32_t, int32_t>& remapped_indices,
|
||||
llvm::SmallVectorImpl<tensorflow::tpu::PaddingMap>* remapped_paddings) {
|
||||
auto bad_index_msg = [num_replicated_args](int32_t index,
|
||||
llvm::StringRef arg_type,
|
||||
int32_t arg_index) {
|
||||
auto bad_index_msg = [](int32_t index, llvm::StringRef arg_type,
|
||||
int32_t arg_index) {
|
||||
return llvm::formatv(
|
||||
"bad '{0}' attribute at index {1}, {2} must be in [0, {3}), got "
|
||||
"{4}",
|
||||
kPaddingMapAttr, index, arg_type, num_replicated_args, arg_index)
|
||||
"bad '{0}' attribute at index {1}, {2} must be nonnegative, but "
|
||||
"got {3}",
|
||||
kPaddingMapAttr, index, arg_type, arg_index)
|
||||
.str();
|
||||
};
|
||||
|
||||
@ -111,12 +120,12 @@ LogicalResult GetRemappedPaddings(
|
||||
kPaddingMapAttr, idx, padding.getValue()));
|
||||
|
||||
const int32_t arg_index = padding_proto.arg_index();
|
||||
if (arg_index >= num_replicated_args || arg_index < 0)
|
||||
if (arg_index < 0)
|
||||
return cluster_func.emitOpError()
|
||||
<< bad_index_msg(idx, "arg_index", arg_index);
|
||||
|
||||
const int32_t padding_arg_index = padding_proto.padding_arg_index();
|
||||
if (padding_arg_index >= num_replicated_args || padding_arg_index < 0)
|
||||
if (padding_arg_index < 0)
|
||||
return cluster_func.emitOpError()
|
||||
<< bad_index_msg(idx, "padding_arg_index", padding_arg_index);
|
||||
|
||||
@ -175,17 +184,21 @@ LogicalResult RemapAndAssignPaddingMaps(tf_device::ClusterFuncOp cluster_func,
|
||||
auto replicate = cluster_func.getParentOfType<tf_device::ReplicateOp>();
|
||||
// LaunchFunc is not replicated, there will be no padding.
|
||||
if (!replicate) return success();
|
||||
const int num_replicated_args = replicate.GetBody().getNumArguments();
|
||||
|
||||
auto func = symbol_table->lookup<FuncOp>(cluster_func.func());
|
||||
if (!func) return success();
|
||||
|
||||
auto replicated_input_indices_attr =
|
||||
replicate.getAttrOfType<ArrayAttr>(kReplicatedInputIndicesAttr);
|
||||
if (!replicated_input_indices_attr) return success();
|
||||
|
||||
llvm::SmallDenseMap<int32_t, int32_t> remapped_indices =
|
||||
GetRemappedReplicatedInputIndices(cluster_func, replicate);
|
||||
GetRemappedReplicatedInputIndices(cluster_func, replicate,
|
||||
replicated_input_indices_attr);
|
||||
|
||||
llvm::SmallVector<tensorflow::tpu::PaddingMap, 4> remapped_paddings;
|
||||
if (failed(GetRemappedPaddings(cluster_func, num_replicated_args,
|
||||
remapped_indices, &remapped_paddings)))
|
||||
if (failed(GetRemappedPaddings(cluster_func, remapped_indices,
|
||||
&remapped_paddings)))
|
||||
return failure();
|
||||
|
||||
AnnotateFunctionArgumentsWithPaddings(func, remapped_paddings);
|
||||
|
@ -26,6 +26,8 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/tpu_rewrite_device_util.h"
|
||||
|
||||
namespace mlir {
|
||||
namespace TFTPU {
|
||||
@ -91,13 +93,14 @@ void MoveOutsideClusterOpsToLaunchOp(tf_device::LaunchOp launch_op,
|
||||
|
||||
// Creates a `tf_device::LaunchOp` to wrap cluster ops.
|
||||
tf_device::LaunchOp CreateLaunchOpForOutsideCluster(
|
||||
OpBuilder* builder, Operation* last_cluster_op) {
|
||||
// TODO(b/154363171): Set the CPU device.
|
||||
OpBuilder* builder, Operation* last_cluster_op,
|
||||
llvm::StringRef host_device) {
|
||||
// An empty string placeholder is used for the device as that will be later
|
||||
// populated with the device of the associated TPUReplicateMetadata op.
|
||||
llvm::SmallVector<Type, 8> result_types;
|
||||
auto launch_op = builder->create<tf_device::LaunchOp>(
|
||||
last_cluster_op->getLoc(), builder->getStringAttr(""), result_types);
|
||||
last_cluster_op->getLoc(), builder->getStringAttr(host_device),
|
||||
result_types);
|
||||
|
||||
launch_op.body().push_back(new Block);
|
||||
|
||||
@ -253,8 +256,9 @@ void MoveOutsideCompiledOps(
|
||||
|
||||
// Creates a `parallel_execute` op in place of launch with 'clusters` and
|
||||
// 'launch` as regions.
|
||||
void CreateParallelExecuteFromOutsideClusters(
|
||||
tf_device::ClusterOp tpu_cluster, const OutsideClusterMap& clusters) {
|
||||
void CreateParallelExecuteFromOutsideClusters(tf_device::ClusterOp tpu_cluster,
|
||||
const OutsideClusterMap& clusters,
|
||||
llvm::StringRef host_device) {
|
||||
OpBuilder builder(tpu_cluster);
|
||||
// Create parallel_execute regions. The original TPU cluster computation
|
||||
// is the extra region.
|
||||
@ -269,8 +273,8 @@ void CreateParallelExecuteFromOutsideClusters(
|
||||
Block& outside_block =
|
||||
parallel_execute_op.GetRegionBlockWithIndex(cluster.index());
|
||||
builder.setInsertionPointToEnd(&outside_block);
|
||||
tf_device::LaunchOp host_launch_op =
|
||||
CreateLaunchOpForOutsideCluster(&builder, cluster_ops.back());
|
||||
tf_device::LaunchOp host_launch_op = CreateLaunchOpForOutsideCluster(
|
||||
&builder, cluster_ops.back(), host_device);
|
||||
|
||||
// Determine if there are any inputs that are provided out of cluster.
|
||||
auto external_inputs = GetExternalOperands(cluster_ops);
|
||||
@ -307,8 +311,14 @@ void CreateParallelExecuteFromOutsideClusters(
|
||||
}
|
||||
|
||||
void TPUExtractOutsideCompilation::runOnOperation() {
|
||||
// Get runtime devices information from the closest parent module.
|
||||
auto module = getOperation();
|
||||
mlir::TF::RuntimeDevices devices;
|
||||
if (failed(tensorflow::GetDevicesFromOp(module, &devices)))
|
||||
return signalPassFailure();
|
||||
|
||||
auto extract_result =
|
||||
getOperation().walk([&](tf_device::ClusterOp tpu_cluster) {
|
||||
module.walk([&](tf_device::ClusterOp tpu_cluster) {
|
||||
OutsideClusterMap clusters;
|
||||
if (failed(CollectAndGroupOutsideClusterOps(&tpu_cluster.GetBody(),
|
||||
&clusters)))
|
||||
@ -316,7 +326,11 @@ void TPUExtractOutsideCompilation::runOnOperation() {
|
||||
|
||||
if (clusters.empty()) return WalkResult::advance();
|
||||
|
||||
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters);
|
||||
std::string host_device;
|
||||
tensorflow::GetHostDeviceOutsideComputation(devices, tpu_cluster,
|
||||
&host_device);
|
||||
CreateParallelExecuteFromOutsideClusters(tpu_cluster, clusters,
|
||||
host_device);
|
||||
|
||||
return WalkResult::advance();
|
||||
});
|
||||
|
@ -91,6 +91,7 @@ class PyClient : public std::enable_shared_from_this<PyClient> {
|
||||
explicit PyClient(std::shared_ptr<PjRtClient> pjrt_client);
|
||||
|
||||
PjRtClient* pjrt_client() const { return pjrt_client_.get(); }
|
||||
std::shared_ptr<PjRtClient> shared_pjrt_client() { return pjrt_client_; }
|
||||
|
||||
const std::string& platform_name() const {
|
||||
return pjrt_client_->platform_name();
|
||||
|
@ -34,13 +34,13 @@ END
|
||||
description: <<END
|
||||
The type of padding algorithm to use.
|
||||
|
||||
We specify the size-related attributes as:
|
||||
The size-related attributes are specified as follows:
|
||||
|
||||
```python
|
||||
ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
|
||||
strides = [1, stride_planes, strides_rows, strides_cols, 1]
|
||||
ksizes = [1, ksize_planes, ksize_rows, ksize_cols, 1]
|
||||
strides = [1, stride_planes, strides_rows, strides_cols, 1]
|
||||
```
|
||||
END
|
||||
}
|
||||
summary: "Extract `patches` from `input` and put them in the \"depth\" output dimension. 3D extension of `extract_image_patches`."
|
||||
summary: "Extract `patches` from `input` and put them in the `\"depth\"` output dimension. 3D extension of `extract_image_patches`."
|
||||
}
|
||||
|
@ -9,10 +9,9 @@ def _lookup_file(filegroup, path):
|
||||
return file
|
||||
return None
|
||||
|
||||
def _gen_kernel_image_hdr_impl(ctx):
|
||||
if not ctx.attr.gpu_archs:
|
||||
fail("No GPU architecture specified, use --config=cuda or similar")
|
||||
CubinInfo = provider(fields = ["cubins"])
|
||||
|
||||
def _gen_kernel_cubin_impl(ctx):
|
||||
name = ctx.attr.name
|
||||
tile_sizes = ctx.attr.tile_size.replace("x", ",")
|
||||
cmd_args = []
|
||||
@ -22,7 +21,6 @@ def _gen_kernel_image_hdr_impl(ctx):
|
||||
cmd_args.append("--unroll_factors=%s" % ctx.attr.unroll_factors)
|
||||
|
||||
cubins = []
|
||||
images = []
|
||||
for arch in ctx.attr.gpu_archs:
|
||||
# TODO(b/152737872): 'compute_' should generate both SASS and PTX.
|
||||
arch = arch.replace("compute_", "sm_")
|
||||
@ -41,13 +39,36 @@ def _gen_kernel_image_hdr_impl(ctx):
|
||||
mnemonic = "compile",
|
||||
)
|
||||
cubins.append(cubin)
|
||||
return [CubinInfo(cubins = cubins)]
|
||||
|
||||
_gen_kernel_cubin_rule = rule(
|
||||
implementation = _gen_kernel_cubin_impl,
|
||||
attrs = {
|
||||
"mlir_op": attr.label(mandatory = True, allow_single_file = True),
|
||||
"tile_size": attr.string(mandatory = True),
|
||||
"same_shape": attr.string(),
|
||||
"unroll_factors": attr.string(),
|
||||
"gpu_archs": attr.string_list(mandatory = True),
|
||||
"_tool": attr.label(
|
||||
executable = True,
|
||||
default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
|
||||
cfg = "host",
|
||||
),
|
||||
},
|
||||
output_to_genfiles = True,
|
||||
)
|
||||
|
||||
def _gen_kernel_image_hdr_impl(ctx):
|
||||
images = []
|
||||
for cubin in ctx.attr.input[CubinInfo].cubins:
|
||||
arch = cubin.path.split(".")[-2]
|
||||
images.append("--image=profile=%s,file=%s" % (arch, cubin.path))
|
||||
|
||||
# Generate fatbin file from all cubins.
|
||||
fatbin = ctx.actions.declare_file("%s.fatbin" % name)
|
||||
fatbin = ctx.actions.declare_file("%s.fatbin" % ctx.attr.name)
|
||||
ctx.actions.run(
|
||||
outputs = [fatbin],
|
||||
inputs = cubins,
|
||||
inputs = ctx.attr.input[CubinInfo].cubins,
|
||||
executable = _lookup_file(ctx.attr._cuda_root, "bin/fatbinary"),
|
||||
arguments = [
|
||||
"--64",
|
||||
@ -73,37 +94,31 @@ _gen_kernel_image_hdr_rule = rule(
|
||||
implementation = _gen_kernel_image_hdr_impl,
|
||||
output_to_genfiles = True,
|
||||
attrs = {
|
||||
"mlir_op": attr.label(mandatory = True, allow_single_file = True),
|
||||
"tile_size": attr.string(mandatory = True),
|
||||
"same_shape": attr.string(),
|
||||
"unroll_factors": attr.string(),
|
||||
"input": attr.label(mandatory = True, providers = [CubinInfo]),
|
||||
"out": attr.output(mandatory = True),
|
||||
"symbol": attr.string(mandatory = True),
|
||||
"gpu_archs": attr.string_list(mandatory = True),
|
||||
"_cuda_root": attr.label(
|
||||
default = Label("@local_config_cuda//cuda:cuda_root"),
|
||||
),
|
||||
"_tool": attr.label(
|
||||
executable = True,
|
||||
default = Label("//tensorflow/compiler/mlir/tools/kernel_gen:tf_to_cubin"),
|
||||
cfg = "host",
|
||||
),
|
||||
},
|
||||
)
|
||||
|
||||
def _gen_kernel_image_hdr(name, mlir_op, tile_size, tags = [], same_shape = None, unroll_factors = None):
|
||||
def _gen_kernel_image_hdr(name, mlir_op, tile_size, same_shape = None, unroll_factors = None):
|
||||
"""Generates a C header with fatbin data from a Tensorflow op."""
|
||||
if cuda_gpu_architectures():
|
||||
_gen_kernel_image_hdr_rule(
|
||||
name = name,
|
||||
_gen_kernel_cubin_rule(
|
||||
name = name + "_cubin",
|
||||
mlir_op = mlir_op,
|
||||
tile_size = tile_size,
|
||||
same_shape = same_shape,
|
||||
unroll_factors = unroll_factors,
|
||||
gpu_archs = cuda_gpu_architectures(),
|
||||
)
|
||||
_gen_kernel_image_hdr_rule(
|
||||
name = name,
|
||||
input = ":" + name + "_cubin",
|
||||
out = "%s.h" % name,
|
||||
symbol = "k%s" % name.replace("_", " ").title().replace(" ", ""),
|
||||
gpu_archs = cuda_gpu_architectures(),
|
||||
tags = tags,
|
||||
)
|
||||
|
||||
def _gen_mlir_op_impl(ctx):
|
||||
@ -157,7 +172,6 @@ def gen_kernel_library(name, types, tile_size, tags = [], same_shape = None, unr
|
||||
name = "{name}_{type}_kernel".format(name = name, type = type),
|
||||
mlir_op = "{name}_{type}.mlir".format(name = name, type = type),
|
||||
tile_size = tile_size,
|
||||
tags = tags,
|
||||
same_shape = same_shape,
|
||||
unroll_factors = unroll_factors,
|
||||
)
|
||||
|
@ -108,7 +108,7 @@ limitations under the License.
|
||||
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_PRODUCER 0
|
||||
#define TF_GRAPH_DEF_VERSION_MIN_CONSUMER 0
|
||||
#define TF_GRAPH_DEF_VERSION 441 // Updated: 2020/6/23
|
||||
#define TF_GRAPH_DEF_VERSION 442 // Updated: 2020/6/24
|
||||
|
||||
// Checkpoint compatibility versions (the versions field in SavedSliceMeta).
|
||||
//
|
||||
|
@ -28,10 +28,12 @@ cc_library(
|
||||
srcs = ["tpu_compile_op_common.cc"],
|
||||
hdrs = ["tpu_compile_op_common.h"],
|
||||
deps = [
|
||||
":tpu_compile_op_options",
|
||||
":tpu_compile_op_support",
|
||||
":tpu_mesh_state_interface",
|
||||
":tpu_program_group_interface",
|
||||
":tpu_util",
|
||||
":tpu_util_c_api_hdrs",
|
||||
":tpu_util_hdrs",
|
||||
"//tensorflow/compiler/jit:flags",
|
||||
"//tensorflow/compiler/jit:shape_inference",
|
||||
@ -50,6 +52,7 @@ cc_library(
|
||||
"//tensorflow/core/tpu:tpu_configuration",
|
||||
"//tensorflow/core/tpu:tpu_defs",
|
||||
"//tensorflow/stream_executor/tpu:tpu_platform_interface",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
],
|
||||
alwayslink = 1,
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.

#include <string>

#include "absl/strings/string_view.h"
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/tf2xla/tf2xla_util.h"
#include "tensorflow/compiler/xla/client/client_library.h"
@ -28,8 +29,10 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/protobuf/tpu/compile_metadata.pb.h"
#include "tensorflow/core/protobuf/tpu/dynamic_padding.pb.h"
#include "tensorflow/core/tpu/kernels/tpu_compile_op_options.h"
#include "tensorflow/core/tpu/kernels/tpu_program_group_interface.h"
#include "tensorflow/core/tpu/kernels/tpu_util.h"
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
#include "tensorflow/core/tpu/tpu_configuration.h"
#include "tensorflow/core/tpu/tpu_defs.h"

@ -518,5 +521,41 @@ Status TpuCompileOpKernelCommon::OptimizeGraph(
return Status::OK();
}

void TpuCompileOpKernelCommon::Compute(OpKernelContext* ctx) {
VLOG(1) << "Cloud TPU: TpuCompileOpKernelCommon::Compute";

std::shared_ptr<std::atomic<bool>> done(new std::atomic<bool>(false));

CancellationToken token =
ctx->cancellation_manager()->get_cancellation_token();
const bool already_cancelled =
!ctx->cancellation_manager()->RegisterCallback(token, [ctx, done]() {
if (TpuCompile_ShouldTpuCompileOpIgnoreCancellation()) {
return;
}

// Sleep and exit in another thread so the cancellation manager can
// continue running callbacks.
ctx->env()->SchedClosure([ctx, done]() { ExitCountdown(ctx, done); });
});

// If the RPC was cancelled before we registered the cancellation callback,
// don't compile the TPU program.
OP_REQUIRES(ctx, !already_cancelled,
errors::Cancelled("RPC cancelled, not compiling TPU program"));

// We only want to abort the process if a cancellation actually occurs during
// compilation; we must deregister the callback in the success case. It
// doesn't hurt to also deregister the callback in the failure case; the
// CancellationManager ensures that already-registered callbacks will be run
// once cancellation has started.
auto cancellation_cleanup = xla::MakeCleanup([ctx, token, done] {
ctx->cancellation_manager()->DeregisterCallback(token);
done->store(true);
});

OP_REQUIRES_OK(ctx, ComputeInternal(ctx));
}

} // namespace tpu
} // namespace tensorflow

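The Compute method above registers a cancellation callback, refuses to start if the RPC was already cancelled, and always deregisters the callback on exit. A rough Go sketch of that same guard pattern using context.AfterFunc (Go 1.21+), offered only as an illustration of the pattern, not as TensorFlow code:

```go
package main

import (
	"context"
	"fmt"
	"time"
)

func compile(ctx context.Context) error {
	// Refuse to start if the request was already cancelled.
	if err := ctx.Err(); err != nil {
		return fmt.Errorf("RPC cancelled, not compiling: %w", err)
	}
	// Register a callback that fires only if cancellation happens later.
	stop := context.AfterFunc(ctx, func() {
		fmt.Println("cancellation requested during compilation")
	})
	// Always deregister, in both the success and failure paths.
	defer stop()
	time.Sleep(10 * time.Millisecond) // stand-in for the real compilation work
	return nil
}

func main() {
	fmt.Println(compile(context.Background()))
}
```
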
@ -53,7 +53,8 @@ class TpuCompileOpKernelCommon {
|
||||
|
||||
virtual ~TpuCompileOpKernelCommon() = default;
|
||||
|
||||
virtual void Compute(OpKernelContext* ctx) = 0;
|
||||
void Compute(OpKernelContext* ctx);
|
||||
virtual Status ComputeInternal(OpKernelContext* ctx) = 0;
|
||||
|
||||
// Computes shapes for each argument. Uses both the static shape from the
|
||||
// metadata, and the dynamic shapes where the static shape is not
|
||||
|
@ -95,6 +95,5 @@ Status DynamicShapesToTensorShapes(const InputList& dynamic_shapes,
|
||||
}
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
@ -68,7 +68,6 @@ Status TpuPaddedShapeFn(const Tensor& tensor, xla::Shape* shape);
|
||||
|
||||
// A callback called on exit.
|
||||
void LogAndExit(int code);
|
||||
|
||||
} // namespace tpu
|
||||
} // namespace tensorflow
|
||||
|
||||
|
@ -31,6 +31,12 @@ void TpuCompile_ToTpuShapeRepresentation(
|
||||
bool use_fast_memory, TpuSerializedProto* serialized_tensor_shape,
|
||||
SE_Status* status);
|
||||
|
||||
// XLA compilation cannot be cancelled. To avoid hanging, the TF worker will exit
|
||||
// when cancellation is requested for an XLA compile op. Some tests require this
|
||||
// behavior to be disabled, and we test for this condition with the following
|
||||
// flag function.
|
||||
bool TpuCompile_ShouldTpuCompileOpIgnoreCancellation();
|
||||
|
||||
} // extern "C"
|
||||
|
||||
struct TfTpu_UtilApiFn {
|
||||
|
@ -26,6 +26,7 @@ import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
"io"
|
||||
"math/bits"
|
||||
"reflect"
|
||||
"runtime"
|
||||
"unsafe"
|
||||
@ -80,7 +81,7 @@ func NewTensor(value interface{}) (*Tensor, error) {
|
||||
if dataType == String {
|
||||
// TF_STRING tensors are encoded as an array of 8-byte offsets
|
||||
// followed by string data. See c_api.h.
|
||||
nbytes = uintptr(nflattened*8) + byteSizeOfEncodedStrings(value)
|
||||
nbytes = uintptr(nflattened*8 + int64(byteSizeOfEncodedStrings(val)))
|
||||
}
|
||||
var shapePtr *C.int64_t
|
||||
if len(shape) > 0 {
|
||||
@ -94,9 +95,22 @@ func NewTensor(value interface{}) (*Tensor, error) {
|
||||
raw := tensorData(t.c)
|
||||
buf := bytes.NewBuffer(raw[:0:len(raw)])
|
||||
if dataType != String {
|
||||
if err := encodeTensor(buf, val, shape); err != nil {
|
||||
return nil, err
|
||||
if isAllArray(val.Type()) {
|
||||
// We have arrays all the way down, or just primitive types. We can
|
||||
// just copy the memory in as it is all contiguous.
|
||||
if err := copyPtr(buf, unpackEFace(value).data, int(val.Type().Size())); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else {
|
||||
// When there are slices involved the memory for each leaf slice may
|
||||
// not be contiguous with the others or in the order we might
|
||||
// expect, so we need to work our way down to each slice of
|
||||
// primitives and copy them individually
|
||||
if err := encodeTensorWithSlices(buf, val, shape); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
if uintptr(buf.Len()) != nbytes {
|
||||
return nil, bug("NewTensor incorrectly calculated the size of a tensor with type %v and shape %v as %v bytes instead of %v", dataType, shape, nbytes, buf.Len())
|
||||
}
|
||||
@ -112,6 +126,43 @@ func NewTensor(value interface{}) (*Tensor, error) {
|
||||
return t, nil
|
||||
}
|
||||
|
||||
// isAllArray returns true if type is a primitive type or an array of primitive
// types or an array of ... etc. When this is true the data we want is
// contiguous in RAM.
func isAllArray(typ reflect.Type) bool {
switch typ.Kind() {
case reflect.Slice:
return false
case reflect.Array:
return isAllArray(typ.Elem())
default:
// We know the type is slices/arrays of slices/arrays of primitive types.
return true
}
}

// eface defines what an interface type actually is: a pointer to type
// information about the encapsulated type and a pointer to the encapsulated
// value.
type eface struct {
rtype unsafe.Pointer
data unsafe.Pointer
}

// unpackEFace gives us an efficient way to get a pointer to the value carried
// in an interface. If you wrap a pointer type in an interface then the pointer
// is directly stored in the interface struct. If you wrap a value type in an
// interface then the compiler copies the value into a newly allocated piece of
// memory and stores a pointer to that memory in the interface. So we're
// guaranteed to get a pointer. Go reflection doesn't expose the pointer to
// value types straightforwardly as it doesn't want you to think you have a
// reference to the original value. But we just want a pointer to make it
// efficient to read the value, so cheating like this should be safe and
// reasonable.
func unpackEFace(obj interface{}) *eface {
return (*eface)(unsafe.Pointer(&obj))
}

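As the comment notes, unpackEFace relies on the runtime representation of `interface{}`. A standalone sketch of the same trick (the `eface` mirror struct here is an assumption about the current runtime layout, not a stable API), showing how a nested array's contiguous bytes can be viewed through the interface's data pointer:

```go
package main

import (
	"fmt"
	"unsafe"
)

// eface mirrors the runtime layout of an empty interface: a type pointer and a
// data pointer. This is a sketch for illustration, not part of the TF Go API.
type eface struct {
	rtype unsafe.Pointer
	data  unsafe.Pointer
}

func main() {
	// Arrays all the way down are contiguous, so the whole value can be read
	// through the interface's data pointer in one copy.
	v := [2][3]float32{{1, 2, 3}, {4, 5, 6}}
	var i interface{} = v
	e := (*eface)(unsafe.Pointer(&i))
	raw := unsafe.Slice((*byte)(e.data), unsafe.Sizeof(v))
	fmt.Println(len(raw)) // 24 bytes: 6 float32 values
}
```
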
// ReadTensor constructs a Tensor with the provided type and shape from the
|
||||
// serialized tensor contents in r.
|
||||
//
|
||||
@ -168,21 +219,152 @@ func (t *Tensor) Shape() []int64 { return t.shape }
|
||||
// Tensor(int64, 0): int64
|
||||
// Tensor(float64, 3): [][][]float64
|
||||
func (t *Tensor) Value() interface{} {
|
||||
typ := typeOf(t.DataType(), t.Shape())
|
||||
val := reflect.New(typ)
|
||||
raw := tensorData(t.c)
|
||||
if t.DataType() != String {
|
||||
if err := decodeTensor(bytes.NewReader(raw), t.Shape(), typ, val); err != nil {
|
||||
panic(bug("unable to decode Tensor of type %v and shape %v - %v", t.DataType(), t.Shape(), err))
|
||||
shape := t.Shape()
|
||||
dt := t.DataType()
|
||||
return decodeTensor(raw, shape, dt).Interface()
|
||||
}
|
||||
|
||||
func decodeTensor(raw []byte, shape []int64, dt DataType) reflect.Value {
|
||||
// Create a 1-dimensional slice of the base large enough for the data and
|
||||
// copy the data in.
|
||||
n := int(numElements(shape))
|
||||
|
||||
var (
|
||||
slice reflect.Value
|
||||
typ reflect.Type
|
||||
)
|
||||
if dt == String {
|
||||
strs, err := decodeOneDimString(raw, n)
|
||||
if err != nil {
|
||||
panic(bug("unable to decode string with shape %v: %v", shape, err))
|
||||
}
|
||||
slice = reflect.ValueOf(strs)
|
||||
typ = slice.Type()
|
||||
} else {
|
||||
nflattened := numElements(t.Shape())
|
||||
d := stringDecoder{offsets: bytes.NewReader(raw[0 : 8*nflattened]), data: raw[8*nflattened:], status: newStatus()}
|
||||
if err := d.decode(val, t.Shape()); err != nil {
|
||||
panic(bug("unable to decode String tensor with shape %v - %v", t.Shape(), err))
|
||||
}
|
||||
typ = typeForDataType(dt)
|
||||
l := n * int(typ.Size())
|
||||
typ = reflect.SliceOf(typ)
|
||||
slice = reflect.MakeSlice(typ, n, n)
|
||||
baseBytes := *(*[]byte)(unsafe.Pointer(&sliceHeader{
|
||||
Data: unsafe.Pointer(slice.Pointer()),
|
||||
Len: l,
|
||||
Cap: l,
|
||||
}))
|
||||
copy(baseBytes, raw)
|
||||
}
|
||||
return reflect.Indirect(val).Interface()
|
||||
|
||||
// Now we have the data in place in the base slice we can add the
|
||||
// dimensions. We want to walk backwards through the shape. If the shape is
|
||||
// length 1 or 0 then we're already done.
|
||||
if len(shape) == 0 {
|
||||
return slice.Index(0)
|
||||
}
|
||||
if len(shape) == 1 {
|
||||
return slice
|
||||
}
|
||||
// We have a special case if the tensor has no data. Our backing slice is
|
||||
// empty, but we still want to create slices following the shape. In this
|
||||
// case only the final part of the shape will be 0 and we want to recalculate
|
||||
// n at this point ignoring that 0.
|
||||
// For example if our shape is 3 * 2 * 0 then n will be zero, but we still
|
||||
// want 6 zero length slices to group as follows.
|
||||
// {{} {}} {{} {}} {{} {}}
|
||||
if n == 0 {
|
||||
n = int(numElements(shape[:len(shape)-1]))
|
||||
}
|
||||
for i := len(shape) - 2; i >= 0; i-- {
|
||||
underlyingSize := typ.Elem().Size()
|
||||
typ = reflect.SliceOf(typ)
|
||||
subsliceLen := int(shape[i+1])
|
||||
if subsliceLen != 0 {
|
||||
n = n / subsliceLen
|
||||
}
|
||||
// Just using reflection it is difficult to avoid unnecessary
|
||||
// allocations while setting up the sub-slices as the Slice function on
|
||||
// a slice Value allocates. So we end up doing pointer arithmetic!
|
||||
// Pointer() on a slice gives us access to the data backing the slice.
|
||||
// We insert slice headers directly into this data.
|
||||
data := unsafe.Pointer(slice.Pointer())
|
||||
nextSlice := reflect.MakeSlice(typ, n, n)
|
||||
|
||||
for j := 0; j < n; j++ {
|
||||
// This is equivalent to nextSlice[j] = slice[j*subsliceLen : (j+1)*subsliceLen]
setSliceInSlice(nextSlice, j, sliceHeader{
|
||||
Data: unsafe.Pointer(uintptr(data) + (uintptr(j*subsliceLen) * underlyingSize)),
|
||||
Len: subsliceLen,
|
||||
Cap: subsliceLen,
|
||||
})
|
||||
}
|
||||
|
||||
slice = nextSlice
|
||||
}
|
||||
return slice
|
||||
}
|
||||
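// For contrast, a hypothetical reflect-only version of the regrouping loop
// above (illustrative, not part of the original file). It produces the same
// nested value but goes through Value.Slice, which allocates a fresh slice
// header per element - the cost the pointer arithmetic in decodeTensor avoids.
// Assumes a non-empty shape whose dimensions are all non-zero.
func regroupWithReflect(flat reflect.Value, shape []int64) reflect.Value {
	if len(shape) <= 1 {
		return flat
	}
	n := 1
	for _, d := range shape {
		n *= int(d)
	}
	slice := flat // length n, element type is the tensor's base type
	for i := len(shape) - 2; i >= 0; i-- {
		subsliceLen := int(shape[i+1])
		n /= subsliceLen
		outer := reflect.MakeSlice(reflect.SliceOf(slice.Type()), n, n)
		for j := 0; j < n; j++ {
			// outer[j] = slice[j*subsliceLen : (j+1)*subsliceLen]
			outer.Index(j).Set(slice.Slice(j*subsliceLen, (j+1)*subsliceLen))
		}
		slice = outer
	}
	return slice
}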
|
||||
// setSliceInSlice sets slice[index] = content.
|
||||
func setSliceInSlice(slice reflect.Value, index int, content sliceHeader) {
|
||||
const sliceSize = unsafe.Sizeof(sliceHeader{})
|
||||
// We must cast slice.Pointer to uintptr & back again to avoid GC issues.
// See https://github.com/google/go-cmp/issues/167#issuecomment-546093202
|
||||
*(*sliceHeader)(unsafe.Pointer(uintptr(unsafe.Pointer(slice.Pointer())) + (uintptr(index) * sliceSize))) = content
|
||||
}
|
||||
|
||||
// decodeOneDimString decodes a string tensor into a one-dimensional []string.
|
||||
func decodeOneDimString(raw []byte, nStrings int) ([]string, error) {
|
||||
// Start by making an array of all the strings
|
||||
strs := make([]string, nStrings)
|
||||
// The first nStrings * 8 bytes of raw are offsets into the second half of
|
||||
// the raw data. This second half is where the strings are encoded.
|
||||
offsets := (*(*[]int64)(unsafe.Pointer(&raw)))[:nStrings]
|
||||
|
||||
// Reset raw after the offsets. Now the offsets will work relative to raw
|
||||
raw = raw[nStrings*8:]
|
||||
// Next we work out the final length of the string data so we can copy the
|
||||
// good data out of raw (which is owned by the C tensor and won't be safe
|
||||
// to access if the tensor is freed)
|
||||
r := bytes.NewReader(raw)
|
||||
var totalLength int
|
||||
for _, offset := range offsets {
|
||||
// At each offset we should find a varint length of a string.
|
||||
// Errors here should mean the tensor is corrupt.
|
||||
if _, err := r.Seek(offset, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
totalLength += int(l)
|
||||
}
|
||||
|
||||
// Let's allocate a big buffer to carry our string data.
stringData := make([]byte, 0, totalLength)
|
||||
// Now copy the string data across into our new buffer, keeping track of the
|
||||
// location of each string in the strs slice.
|
||||
var cursor int
|
||||
for i, offset := range offsets {
|
||||
// At each offset we should find a varint length. Read it
|
||||
if _, err := r.Seek(offset, io.SeekStart); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
l, err := binary.ReadUvarint(r)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// Then copy the actual string into our large buffer
|
||||
target := stringData[cursor : cursor+int(l)]
|
||||
if _, err := r.Read(target); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
// Track where this string data is.
|
||||
strs[i] = *(*string)(unsafe.Pointer(&target))
|
||||
cursor += int(l)
|
||||
}
|
||||
|
||||
// So now we have a big slice of strings
|
||||
return strs, nil
|
||||
}
|
||||
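// Hypothetical inverse of decodeOneDimString (illustrative, not part of the
// original file), handy for seeing the assumed layout in one place: an 8-byte
// offset per string in native byte order, followed by each string stored as a
// uvarint length plus its bytes. Uses nativeEndian from this package and
// "encoding/binary".
func encodeOneDimString(strs []string) []byte {
	offsets := make([]byte, 8*len(strs))
	var data []byte
	for i, s := range strs {
		nativeEndian.PutUint64(offsets[8*i:], uint64(len(data)))
		var lenBuf [binary.MaxVarintLen64]byte
		n := binary.PutUvarint(lenBuf[:], uint64(len(s)))
		data = append(data, lenBuf[:n]...)
		data = append(data, s...)
	}
	// Round trip: decodeOneDimString(encodeOneDimString(x), len(x)) == x.
	return append(offsets, data...)
}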
|
||||
// WriteContentsTo writes the serialized contents of t to w.
|
||||
@ -261,18 +443,18 @@ func shapeAndDataTypeOf(val reflect.Value) (shape []int64, dt DataType, err erro
|
||||
return shape, dt, fmt.Errorf("unsupported type %v", typ)
|
||||
}
|
||||
|
||||
// typeOf converts from a DataType and Shape to the equivalent Go type.
|
||||
func typeOf(dt DataType, shape []int64) reflect.Type {
|
||||
var ret reflect.Type
|
||||
func typeForDataType(dt DataType) reflect.Type {
|
||||
for _, t := range types {
|
||||
if dt == DataType(t.dataType) {
|
||||
ret = t.typ
|
||||
break
|
||||
return t.typ
|
||||
}
|
||||
}
|
||||
if ret == nil {
|
||||
panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
|
||||
}
|
||||
panic(bug("DataType %v is not supported (see https://www.tensorflow.org/code/tensorflow/core/framework/types.proto)", dt))
|
||||
}
|
||||
|
||||
// typeOf converts from a DataType and Shape to the equivalent Go type.
|
||||
func typeOf(dt DataType, shape []int64) reflect.Type {
|
||||
ret := typeForDataType(dt)
|
||||
for range shape {
|
||||
ret = reflect.SliceOf(ret)
|
||||
}
|
||||
@ -289,109 +471,93 @@ func numElements(shape []int64) int64 {
|
||||
|
||||
// byteSizeOfEncodedStrings returns the size of the encoded strings in val.
|
||||
// val MUST be a string, or a container (array/slice etc.) of strings.
|
||||
func byteSizeOfEncodedStrings(val interface{}) uintptr {
|
||||
if s, ok := val.(string); ok {
|
||||
return uintptr(C.TF_StringEncodedSize(C.size_t(len(s))))
|
||||
// TensorFlow encodes strings as the varint-encoded length followed by the
// string bytes. We could call into the C library to do this, but cgo has a
// heavy overhead, so we do the calculation in Go instead.
func byteSizeOfEncodedStrings(val reflect.Value) int {
|
||||
if val.Kind() == reflect.String {
|
||||
return sizeVarUint(uint64(val.Len())) + val.Len()
|
||||
}
|
||||
if val.Kind() != reflect.Slice && val.Kind() != reflect.Array {
|
||||
panic(fmt.Sprintf("unexpected type %s", val.Type()))
|
||||
}
|
||||
// Otherwise must be an array or slice.
|
||||
var size uintptr
|
||||
v := reflect.ValueOf(val)
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
size += byteSizeOfEncodedStrings(v.Index(i).Interface())
|
||||
var size int
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
size += byteSizeOfEncodedStrings(val.Index(i))
|
||||
}
|
||||
return size
|
||||
}
|
||||
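// Worked example for the size calculation above (illustrative only): "hi" is 2
// bytes and its uvarint length prefix is 1 byte, so a [2][1]string holding "hi"
// twice encodes to 2*(1+2) = 6 bytes of string data; in this scheme the 8-byte
// offset table is accounted for separately. Requires "reflect".
func exampleEncodedStringSize() int {
	v := [2][1]string{{"hi"}, {"hi"}}
	return byteSizeOfEncodedStrings(reflect.ValueOf(v)) // 6
}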
|
||||
// encodeTensor writes v to the specified buffer using the format specified in
|
||||
// c_api.h. Use stringEncoder for String tensors.
|
||||
func encodeTensor(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||
switch v.Kind() {
|
||||
case reflect.Bool:
|
||||
b := byte(0)
|
||||
if v.Bool() {
|
||||
b = 1
|
||||
}
|
||||
if err := w.WriteByte(b); err != nil {
|
||||
return err
|
||||
}
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
if err := binary.Write(w, nativeEndian, v.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case reflect.Array, reflect.Slice:
|
||||
// If current dimension is a slice, verify that it has the expected size
|
||||
// Go's type system makes that guarantee for arrays.
|
||||
if v.Kind() == reflect.Slice {
|
||||
expected := int(shape[0])
|
||||
if v.Len() != expected {
|
||||
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
|
||||
}
|
||||
}
|
||||
|
||||
// Optimisation: if only one dimension is left we can use binary.Write() directly for this slice
|
||||
if len(shape) == 1 && v.Len() > 0 {
|
||||
switch v.Index(0).Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
return binary.Write(w, nativeEndian, v.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
subShape := shape[1:]
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
err := encodeTensor(w, v.Index(i), subShape)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported type %v", v.Type())
|
||||
// sizeVarUint determines how many bytes it would take to encode the int v as
|
||||
// an unsigned varint
|
||||
func sizeVarUint(v uint64) int {
|
||||
if v < 0x80 {
|
||||
return 1
|
||||
}
|
||||
return nil
|
||||
bits := bits.Len64(v)
|
||||
return (bits + 6) / 7
|
||||
}
|
||||
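// Quick self-check of sizeVarUint against the standard library (illustrative,
// not part of the original file): the computed size should always match the
// number of bytes binary.PutUvarint actually writes.
func sizeVarUintMatchesStdlib(v uint64) bool {
	var buf [binary.MaxVarintLen64]byte
	return sizeVarUint(v) == binary.PutUvarint(buf[:], v)
}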
|
||||
// decodeTensor decodes the Tensor from the buffer to ptr using the format
|
||||
// specified in c_api.h. Use stringDecoder for String tensors.
|
||||
func decodeTensor(r *bytes.Reader, shape []int64, typ reflect.Type, ptr reflect.Value) error {
|
||||
switch typ.Kind() {
|
||||
case reflect.Bool:
|
||||
b, err := r.ReadByte()
|
||||
// encodeTensorWithSlices writes v to the specified buffer using the format specified in
|
||||
// c_api.h. Use stringEncoder for String tensors.
|
||||
func encodeTensorWithSlices(w *bytes.Buffer, v reflect.Value, shape []int64) error {
|
||||
// If current dimension is a slice, verify that it has the expected size
|
||||
// Go's type system makes that guarantee for arrays.
|
||||
if v.Kind() == reflect.Slice {
|
||||
expected := int(shape[0])
|
||||
if v.Len() != expected {
|
||||
return fmt.Errorf("mismatched slice lengths: %d and %d", v.Len(), expected)
|
||||
}
|
||||
} else if v.Kind() != reflect.Array {
|
||||
return fmt.Errorf("unsupported type %v", v.Type())
|
||||
}
|
||||
|
||||
// Once we have just a single dimension we can just copy the data
|
||||
if len(shape) == 1 && v.Len() > 0 {
|
||||
elt := v.Index(0)
|
||||
if !elt.CanAddr() {
|
||||
panic("cannot take address")
|
||||
}
|
||||
ptr := unsafe.Pointer(elt.Addr().Pointer())
|
||||
return copyPtr(w, ptr, v.Len()*int(elt.Type().Size()))
|
||||
}
|
||||
|
||||
subShape := shape[1:]
|
||||
for i := 0; i < v.Len(); i++ {
|
||||
err := encodeTensorWithSlices(w, v.Index(i), subShape)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
ptr.Elem().SetBool(b == 1)
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
if err := binary.Read(r, nativeEndian, ptr.Interface()); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
case reflect.Slice:
|
||||
val := reflect.Indirect(ptr)
|
||||
val.Set(reflect.MakeSlice(typ, int(shape[0]), int(shape[0])))
|
||||
|
||||
// Optimization: if only one dimension is left we can use binary.Read() directly for this slice
|
||||
if len(shape) == 1 && val.Len() > 0 {
|
||||
switch val.Index(0).Kind() {
|
||||
case reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Float32, reflect.Float64, reflect.Complex64, reflect.Complex128:
|
||||
return binary.Read(r, nativeEndian, val.Interface())
|
||||
}
|
||||
}
|
||||
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
if err := decodeTensor(r, shape[1:], typ.Elem(), val.Index(i).Addr()); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
default:
|
||||
return fmt.Errorf("unsupported type %v", typ)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// It isn't safe to use reflect.SliceHeader as it uses a uintptr for Data and
|
||||
// this is not inspected by the garbage collector
|
||||
type sliceHeader struct {
|
||||
Data unsafe.Pointer
|
||||
Len int
|
||||
Cap int
|
||||
}
|
||||
|
||||
// copyPtr copies the backing data for a slice or array directly into w. Note
|
||||
// we don't need to worry about byte ordering because we want the natural byte
|
||||
// order for the machine we're running on.
|
||||
func copyPtr(w *bytes.Buffer, ptr unsafe.Pointer, l int) error {
|
||||
// Convert our slice header into a []byte so we can call w.Write
|
||||
b := *(*[]byte)(unsafe.Pointer(&sliceHeader{
|
||||
Data: ptr,
|
||||
Len: l,
|
||||
Cap: l,
|
||||
}))
|
||||
_, err := w.Write(b)
|
||||
return err
|
||||
}
|
||||
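// Example use of copyPtr (illustrative, not part of the original file): dump
// the backing bytes of a []float32 into a buffer in one Write call instead of
// encoding element by element. Requires "bytes" and "unsafe".
func float32Bytes(vals []float32) []byte {
	var buf bytes.Buffer
	if len(vals) > 0 {
		// 4 bytes per float32, taken straight from the slice's backing array.
		_ = copyPtr(&buf, unsafe.Pointer(&vals[0]), len(vals)*int(unsafe.Sizeof(vals[0])))
	}
	return buf.Bytes()
}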
|
||||
type stringEncoder struct {
|
||||
offsets io.Writer
|
||||
offsets *bytes.Buffer
|
||||
data []byte
|
||||
offset uint64
|
||||
status *status
|
||||
@ -399,19 +565,18 @@ type stringEncoder struct {
|
||||
|
||||
func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
|
||||
if v.Kind() == reflect.String {
|
||||
if err := binary.Write(e.offsets, nativeEndian, e.offset); err != nil {
|
||||
if err := copyPtr(e.offsets, unsafe.Pointer(&e.offset), int(unsafe.Sizeof(e.offset))); err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
s = v.Interface().(string)
|
||||
src = C.CString(s)
|
||||
srcLen = C.size_t(len(s))
|
||||
dst = (*C.char)(unsafe.Pointer(&e.data[e.offset]))
|
||||
dstLen = C.size_t(uint64(len(e.data)) - e.offset)
|
||||
)
|
||||
e.offset += uint64(C.TF_StringEncode(src, srcLen, dst, dstLen, e.status.c))
|
||||
C.free(unsafe.Pointer(src))
|
||||
return e.status.Err()
|
||||
// A string is encoded as the varint length followed by the string bytes.
|
||||
// We do this in Go to avoid the considerable overhead of a cgo call into
|
||||
// the TensorFlow library.
s := v.String()
|
||||
n := binary.PutUvarint(e.data[e.offset:], uint64(len(s)))
|
||||
e.offset += uint64(n)
|
||||
n = copy(e.data[e.offset:], s)
|
||||
e.offset += uint64(n)
|
||||
return nil
|
||||
}
|
||||
|
||||
if v.Kind() == reflect.Slice {
|
||||
@ -430,45 +595,6 @@ func (e *stringEncoder) encode(v reflect.Value, shape []int64) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
type stringDecoder struct {
|
||||
offsets io.Reader
|
||||
data []byte
|
||||
status *status
|
||||
}
|
||||
|
||||
func (d *stringDecoder) decode(ptr reflect.Value, shape []int64) error {
|
||||
if len(shape) == 0 {
|
||||
var offset uint64
|
||||
if err := binary.Read(d.offsets, nativeEndian, &offset); err != nil {
|
||||
return err
|
||||
}
|
||||
var (
|
||||
src = (*C.char)(unsafe.Pointer(&d.data[offset]))
|
||||
srcLen = C.size_t(len(d.data)) - C.size_t(offset)
|
||||
dst *C.char
|
||||
dstLen C.size_t
|
||||
)
|
||||
if offset > uint64(len(d.data)) {
|
||||
return fmt.Errorf("invalid offsets in String Tensor")
|
||||
}
|
||||
C.TF_StringDecode(src, srcLen, &dst, &dstLen, d.status.c)
|
||||
if err := d.status.Err(); err != nil {
|
||||
return err
|
||||
}
|
||||
s := ptr.Interface().(*string)
|
||||
*s = C.GoStringN(dst, C.int(dstLen))
|
||||
return nil
|
||||
}
|
||||
val := reflect.Indirect(ptr)
|
||||
val.Set(reflect.MakeSlice(typeOf(String, shape), int(shape[0]), int(shape[0])))
|
||||
for i := 0; i < val.Len(); i++ {
|
||||
if err := d.decode(val.Index(i).Addr(), shape[1:]); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func bug(format string, args ...interface{}) error {
|
||||
return fmt.Errorf("BUG: Please report at https://github.com/tensorflow/tensorflow/issues with the note: Go TensorFlow %v: %v", Version(), fmt.Sprintf(format, args...))
|
||||
}
|
||||
@ -489,22 +615,3 @@ func isTensorSerializable(dataType DataType) error {
|
||||
return fmt.Errorf("serialization of tensors with the DataType %d is not yet supported, see https://github.com/tensorflow/tensorflow/issues/6003", dataType)
|
||||
}
|
||||
}
|
||||
|
||||
// nativeEndian is the byte order for the local platform. Used to send back and
|
||||
// forth Tensors with the C API. We test for endianness at runtime because
|
||||
// some architectures can be booted into different endian modes.
|
||||
var nativeEndian binary.ByteOrder
|
||||
|
||||
func init() {
|
||||
buf := [2]byte{}
|
||||
*(*uint16)(unsafe.Pointer(&buf[0])) = uint16(0xABCD)
|
||||
|
||||
switch buf {
|
||||
case [2]byte{0xCD, 0xAB}:
|
||||
nativeEndian = binary.LittleEndian
|
||||
case [2]byte{0xAB, 0xCD}:
|
||||
nativeEndian = binary.BigEndian
|
||||
default:
|
||||
panic("Could not determine native endianness.")
|
||||
}
|
||||
}
|
||||
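// Illustrative usage (not part of the original file): whichever branch the init
// above selects, a multi-byte value written with nativeEndian reads back
// unchanged on the same machine, which is all the C API interchange relies on.
func putAndGetUint16(v uint16) uint16 {
	var b [2]byte
	nativeEndian.PutUint16(b[:], v)
	return nativeEndian.Uint16(b[:])
}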
|
@ -18,6 +18,7 @@ package tensorflow
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
"io"
|
||||
"reflect"
|
||||
"testing"
|
||||
@ -276,6 +277,7 @@ func TestReadTensorReadAll(t *testing.T) {
|
||||
}
|
||||
|
||||
func benchmarkNewTensor(b *testing.B, v interface{}) {
|
||||
b.ReportAllocs()
|
||||
for i := 0; i < b.N; i++ {
|
||||
if t, err := NewTensor(v); err != nil || t == nil {
|
||||
b.Fatalf("(%v, %v)", t, err)
|
||||
@ -283,32 +285,68 @@ func benchmarkNewTensor(b *testing.B, v interface{}) {
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkNewTensor(b *testing.B) {
|
||||
var (
|
||||
// Some sample sizes from the Inception image labeling model.
|
||||
// Where input tensors correspond to a 224x224 RGB image
|
||||
// flattened into a vector.
|
||||
vector [224 * 224 * 3]int32
|
||||
)
|
||||
b.Run("[150528]", func(b *testing.B) { benchmarkNewTensor(b, vector) })
|
||||
}
|
||||
func benchmarkValueTensor(b *testing.B, v interface{}) {
|
||||
t, err := NewTensor(v)
|
||||
if err != nil {
|
||||
b.Fatalf("(%v, %v)", t, err)
|
||||
}
|
||||
b.ReportAllocs()
|
||||
b.ResetTimer()
|
||||
|
||||
func benchmarkDecodeTensor(b *testing.B, t *Tensor) {
|
||||
for i := 0; i < b.N; i++ {
|
||||
_ = t.Value()
|
||||
}
|
||||
}
|
||||
|
||||
func BenchmarkDecodeTensor(b *testing.B) {
|
||||
var (
|
||||
// Some sample sizes from the Inception image labeling model.
|
||||
// Where input tensors correspond to a 224x224 RGB image
|
||||
// flattened into a vector.
|
||||
vector [224 * 224 * 3]int32
|
||||
)
|
||||
t, err := NewTensor(vector)
|
||||
if err != nil {
|
||||
b.Fatalf("(%v, %v)", t, err)
|
||||
func BenchmarkTensor(b *testing.B) {
|
||||
// Some sample sizes from the Inception image labeling model.
|
||||
// Where input tensors correspond to a 224x224 RGB image
|
||||
// flattened into a vector.
|
||||
var vector [224 * 224 * 3]int32
|
||||
var arrays [100][100][100]int32
|
||||
|
||||
l3 := make([][][]float32, 100)
|
||||
l2 := make([][]float32, 100*100)
|
||||
l1 := make([]float32, 100*100*100)
|
||||
for i := range l2 {
|
||||
l2[i] = l1[i*100 : (i+1)*100]
|
||||
}
|
||||
b.Run("[150528]", func(b *testing.B) { benchmarkDecodeTensor(b, t) })
|
||||
for i := range l3 {
|
||||
l3[i] = l2[i*100 : (i+1)*100]
|
||||
}
|
||||
|
||||
s1 := make([]string, 100*100*100)
|
||||
s2 := make([][]string, 100*100)
|
||||
s3 := make([][][]string, 100)
|
||||
for i := range s1 {
|
||||
s1[i] = "cheesit"
|
||||
}
|
||||
for i := range s2 {
|
||||
s2[i] = s1[i*100 : (i+1)*100]
|
||||
}
|
||||
for i := range s3 {
|
||||
s3[i] = s2[i*100 : (i+1)*100]
|
||||
}
|
||||
|
||||
tests := []interface{}{
|
||||
vector,
|
||||
arrays,
|
||||
l1,
|
||||
l2,
|
||||
l3,
|
||||
s1,
|
||||
s2,
|
||||
s3,
|
||||
}
|
||||
b.Run("New", func(b *testing.B) {
|
||||
for _, test := range tests {
|
||||
b.Run(fmt.Sprintf("%T", test), func(b *testing.B) { benchmarkNewTensor(b, test) })
|
||||
}
|
||||
})
|
||||
b.Run("Value", func(b *testing.B) {
|
||||
for _, test := range tests {
|
||||
b.Run(fmt.Sprintf("%T", test), func(b *testing.B) { benchmarkValueTensor(b, test) })
|
||||
}
|
||||
})
|
||||
|
||||
}
|
||||
|
@ -440,76 +440,6 @@ typedef struct TfLiteTensor {
|
||||
// `dims_signature` contains [1, -1, -1, 3]).
|
||||
const TfLiteIntArray* dims_signature;
|
||||
} TfLiteTensor;
|
||||
#else
|
||||
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
|
||||
// contains only the minimum fields required to initialize and prepare a micro
|
||||
// inference graph. The fields in this struct have been ordered from
|
||||
// largest-to-smallest for optimal struct sizeof.
|
||||
//
|
||||
// NOTE: This flag is opt-in only at compile time.
|
||||
typedef struct TfLiteTensor {
|
||||
// TODO(b/155784997): Consider consolidating these quantization fields:
|
||||
// Quantization information. Replaces params field above.
|
||||
TfLiteQuantization quantization;
|
||||
|
||||
// Quantization information.
|
||||
TfLiteQuantizationParams params;
|
||||
|
||||
// A union of data pointers. The appropriate type should be used for a typed
|
||||
// tensor based on `type`.
|
||||
TfLitePtrUnion data;
|
||||
|
||||
// A pointer to a structure representing the dimensionality interpretation
|
||||
// that the buffer should have. NOTE: the product of elements of `dims`
|
||||
// and the element datatype size should be equal to `bytes` below.
|
||||
TfLiteIntArray* dims;
|
||||
|
||||
// The number of bytes required to store the data of this Tensor. I.e.
|
||||
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
|
||||
// type is kTfLiteFloat32 and dims = {3, 2} then
|
||||
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
|
||||
size_t bytes;
|
||||
|
||||
// The data type specification for data stored in `data`. This affects
|
||||
// what member of `data` union should be used.
|
||||
TfLiteType type;
|
||||
|
||||
// How memory is mapped
|
||||
// kTfLiteMmapRo: Memory mapped read only.
|
||||
// i.e. weights
|
||||
// kTfLiteArenaRw: Arena allocated read write memory
|
||||
// (i.e. temporaries, outputs).
|
||||
TfLiteAllocationType allocation_type;
|
||||
|
||||
// True if the tensor is a variable.
|
||||
bool is_variable;
|
||||
} TfLiteTensor;
|
||||
#endif // TF_LITE_STATIC_MEMORY
|
||||
|
||||
#ifndef TF_LITE_STATIC_MEMORY
|
||||
// Free data memory of tensor `t`.
|
||||
void TfLiteTensorDataFree(TfLiteTensor* t);
|
||||
|
||||
// Free quantization data.
|
||||
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
|
||||
|
||||
// Free sparsity parameters.
|
||||
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
|
||||
|
||||
// Free memory of tensor `t`.
|
||||
void TfLiteTensorFree(TfLiteTensor* t);
|
||||
|
||||
// Set all of a tensor's fields (and free any previously allocated data).
|
||||
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
|
||||
TfLiteQuantizationParams quantization, char* buffer,
|
||||
size_t size, TfLiteAllocationType allocation_type,
|
||||
const void* allocation, bool is_variable,
|
||||
TfLiteTensor* tensor);
|
||||
|
||||
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
|
||||
// types other than kTfLiteDynamic will be ignored.
|
||||
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
|
||||
#endif // TF_LITE_STATIC_MEMORY
|
||||
|
||||
// A structure representing an instance of a node.
|
||||
// This structure only exhibits the inputs, outputs and user defined data, not
|
||||
@ -547,6 +477,112 @@ typedef struct TfLiteNode {
|
||||
// WARNING: This is an experimental interface that is subject to change.
|
||||
struct TfLiteDelegate* delegate;
|
||||
} TfLiteNode;
|
||||
#else
|
||||
// NOTE: This flag is opt-in only at compile time.
|
||||
//
|
||||
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
|
||||
// contains only the minimum fields required to initialize and prepare a micro
|
||||
// inference graph. The fields in this struct have been ordered from
|
||||
// largest-to-smallest for optimal struct sizeof.
|
||||
//
|
||||
// This struct does not use:
|
||||
// - allocation
|
||||
// - buffer_handle
|
||||
// - data_is_stale
|
||||
// - delegate
|
||||
// - dims_signature
|
||||
// - name
|
||||
// - sparsity
|
||||
typedef struct TfLiteTensor {
|
||||
// TODO(b/155784997): Consider consolidating these quantization fields:
|
||||
// Quantization information. Replaces params field above.
|
||||
TfLiteQuantization quantization;
|
||||
|
||||
// Quantization information.
|
||||
TfLiteQuantizationParams params;
|
||||
|
||||
// A union of data pointers. The appropriate type should be used for a typed
|
||||
// tensor based on `type`.
|
||||
TfLitePtrUnion data;
|
||||
|
||||
// A pointer to a structure representing the dimensionality interpretation
|
||||
// that the buffer should have. NOTE: the product of elements of `dims`
|
||||
// and the element datatype size should be equal to `bytes` below.
|
||||
TfLiteIntArray* dims;
|
||||
|
||||
// The number of bytes required to store the data of this Tensor. I.e.
|
||||
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
|
||||
// type is kTfLiteFloat32 and dims = {3, 2} then
|
||||
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
|
||||
size_t bytes;
|
||||
|
||||
// The data type specification for data stored in `data`. This affects
|
||||
// what member of `data` union should be used.
|
||||
TfLiteType type;
|
||||
|
||||
// How memory is mapped
|
||||
// kTfLiteMmapRo: Memory mapped read only.
|
||||
// i.e. weights
|
||||
// kTfLiteArenaRw: Arena allocated read write memory
|
||||
// (i.e. temporaries, outputs).
|
||||
TfLiteAllocationType allocation_type;
|
||||
|
||||
// True if the tensor is a variable.
|
||||
bool is_variable;
|
||||
} TfLiteTensor;
|
||||
|
||||
// Specific reduced TfLiteNode struct for TF Micro runtime. This struct contains
|
||||
// only the minimum fields required to represent a node.
|
||||
//
|
||||
// This struct does not use:
|
||||
// - delegate
|
||||
// - intermediates
|
||||
// - temporaries
|
||||
typedef struct TfLiteNode {
|
||||
// Inputs to this node expressed as indices into the simulator's tensors.
|
||||
TfLiteIntArray* inputs;
|
||||
|
||||
// Outputs to this node expressed as indices into the simulator's tensors.
|
||||
TfLiteIntArray* outputs;
|
||||
|
||||
// Opaque data provided by the node implementer through `Registration.init`.
|
||||
void* user_data;
|
||||
|
||||
// Opaque data provided to the node if the node is a builtin. This is usually
|
||||
// a structure defined in builtin_op_data.h
|
||||
void* builtin_data;
|
||||
|
||||
// Custom initial data. This is the opaque data provided in the flatbuffer.
|
||||
// WARNING: This is an experimental interface that is subject to change.
|
||||
const void* custom_initial_data;
|
||||
int custom_initial_data_size;
|
||||
} TfLiteNode;
|
||||
#endif // TF_LITE_STATIC_MEMORY
|
||||
|
||||
#ifndef TF_LITE_STATIC_MEMORY
|
||||
// Free data memory of tensor `t`.
|
||||
void TfLiteTensorDataFree(TfLiteTensor* t);
|
||||
|
||||
// Free quantization data.
|
||||
void TfLiteQuantizationFree(TfLiteQuantization* quantization);
|
||||
|
||||
// Free sparsity parameters.
|
||||
void TfLiteSparsityFree(TfLiteSparsity* sparsity);
|
||||
|
||||
// Free memory of tensor `t`.
|
||||
void TfLiteTensorFree(TfLiteTensor* t);
|
||||
|
||||
// Set all of a tensor's fields (and free any previously allocated data).
|
||||
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
|
||||
TfLiteQuantizationParams quantization, char* buffer,
|
||||
size_t size, TfLiteAllocationType allocation_type,
|
||||
const void* allocation, bool is_variable,
|
||||
TfLiteTensor* tensor);
|
||||
|
||||
// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
|
||||
// types other than kTfLiteDynamic will be ignored.
|
||||
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
|
||||
#endif // TF_LITE_STATIC_MEMORY
|
||||
|
||||
// WARNING: This is an experimental interface that is subject to change.
|
||||
//
|
||||
|
@ -263,6 +263,43 @@ TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseCeil(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseConcatenation(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data) {
|
||||
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
|
||||
|
||||
SafeBuiltinDataAllocator safe_allocator(allocator);
|
||||
std::unique_ptr<TfLiteConcatenationParams,
|
||||
SafeBuiltinDataAllocator::BuiltinDataDeleter>
|
||||
params = safe_allocator.Allocate<TfLiteConcatenationParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
|
||||
const ConcatenationOptions* schema_params =
|
||||
op->builtin_options_as_ConcatenationOptions();
|
||||
|
||||
if (schema_params != nullptr) {
|
||||
params->activation =
|
||||
ConvertActivation(schema_params->fused_activation_function());
|
||||
params->axis = schema_params->axis();
|
||||
} else {
|
||||
// TODO(b/157480169): We should either return kTfLiteError or fill in some
|
||||
// reasonable defaults in the params struct. We are not doing so until we
|
||||
// better understand the ramifications of changing the legacy behavior.
}
|
||||
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
@ -295,6 +332,14 @@ TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseCos(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
@ -339,6 +384,22 @@ TfLiteStatus ParseDequantize(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseEqual(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseFloor(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
@ -385,6 +446,53 @@ TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseGreater(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseGreaterEqual(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
|
||||
|
||||
SafeBuiltinDataAllocator safe_allocator(allocator);
|
||||
std::unique_ptr<TfLitePoolParams,
|
||||
SafeBuiltinDataAllocator::BuiltinDataDeleter>
|
||||
params = safe_allocator.Allocate<TfLitePoolParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
|
||||
const Pool2DOptions* schema_params = op->builtin_options_as_Pool2DOptions();
|
||||
|
||||
if (schema_params != nullptr) {
|
||||
params->padding = ConvertPadding(schema_params->padding());
|
||||
params->stride_width = schema_params->stride_w();
|
||||
params->stride_height = schema_params->stride_h();
|
||||
params->filter_width = schema_params->filter_width();
|
||||
params->filter_height = schema_params->filter_height();
|
||||
params->activation =
|
||||
ConvertActivation(schema_params->fused_activation_function());
|
||||
} else {
|
||||
// TODO(b/157480169): We should either return kTfLiteError or fill in some
|
||||
// reasonable defaults in the params struct. We are not doing so until we
|
||||
// better understand the ramifications of changing the legacy behavior.
}
|
||||
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseReshape(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
@ -532,6 +640,19 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
return ParseArgMin(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_AVERAGE_POOL_2D: {
|
||||
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_CEIL: {
|
||||
return ParseCeil(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_CONCATENATION: {
|
||||
return ParseConcatenation(op, op_type, error_reporter, allocator,
|
||||
builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_CONV_2D: {
|
||||
return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
@ -546,11 +667,32 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_FLOOR: {
|
||||
return ParseFloor(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_FULLY_CONNECTED: {
|
||||
return ParseFullyConnected(op, op_type, error_reporter, allocator,
|
||||
builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_GREATER: {
|
||||
return ParseGreater(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_GREATER_EQUAL: {
|
||||
return ParseGreaterEqual(op, op_type, error_reporter, allocator,
|
||||
builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_MAX_POOL_2D: {
|
||||
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_L2_POOL_2D: {
|
||||
return ParsePool(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_QUANTIZE: {
|
||||
return ParseQuantize(op, op_type, error_reporter, allocator,
|
||||
builtin_data);
|
||||
@ -592,23 +734,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_AVERAGE_POOL_2D:
|
||||
case BuiltinOperator_MAX_POOL_2D:
|
||||
case BuiltinOperator_L2_POOL_2D: {
|
||||
auto params = safe_allocator.Allocate<TfLitePoolParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
if (const auto* pool_params = op->builtin_options_as_Pool2DOptions()) {
|
||||
params->padding = ConvertPadding(pool_params->padding());
|
||||
params->stride_width = pool_params->stride_w();
|
||||
params->stride_height = pool_params->stride_h();
|
||||
params->filter_width = pool_params->filter_width();
|
||||
params->filter_height = pool_params->filter_height();
|
||||
params->activation =
|
||||
ConvertActivation(pool_params->fused_activation_function());
|
||||
}
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_UNIDIRECTIONAL_SEQUENCE_RNN: {
|
||||
auto params = safe_allocator.Allocate<TfLiteSequenceRNNParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
@ -666,18 +791,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_HASHTABLE_LOOKUP:
|
||||
// no-op.
|
||||
return kTfLiteOk;
|
||||
case BuiltinOperator_CONCATENATION: {
|
||||
auto params = safe_allocator.Allocate<TfLiteConcatenationParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
if (const auto* concatenation_params =
|
||||
op->builtin_options_as_ConcatenationOptions()) {
|
||||
params->activation = ConvertActivation(
|
||||
concatenation_params->fused_activation_function());
|
||||
params->axis = concatenation_params->axis();
|
||||
}
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_MUL: {
|
||||
auto params = safe_allocator.Allocate<TfLiteMulParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
@ -1102,10 +1215,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
case BuiltinOperator_EQUAL:
|
||||
case BuiltinOperator_EXP:
|
||||
case BuiltinOperator_EXPAND_DIMS:
|
||||
case BuiltinOperator_CEIL:
|
||||
case BuiltinOperator_FLOOR:
|
||||
case BuiltinOperator_GREATER:
|
||||
case BuiltinOperator_GREATER_EQUAL:
|
||||
case BuiltinOperator_HARD_SWISH:
|
||||
case BuiltinOperator_LESS:
|
||||
case BuiltinOperator_LESS_EQUAL:
|
||||
|
@ -91,10 +91,23 @@ TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseCeil(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseConcatenation(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseCos(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseDepthwiseConv2D(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
@ -105,11 +118,32 @@ TfLiteStatus ParseDequantize(const Operator* op, BuiltinOperator op_type,
|
||||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseEqual(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseFloor(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseFullyConnected(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseGreater(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseGreaterEqual(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
void** builtin_data);
|
||||
|
||||
TfLiteStatus ParsePool(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseQuantize(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator,
|
||||
|
@ -49,6 +49,7 @@ cc_library(
|
||||
":tensor_type",
|
||||
":util",
|
||||
"//tensorflow/lite/delegates/gpu/common:access_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"//tensorflow/lite/delegates/gpu/common:types",
|
||||
"//tensorflow/lite/delegates/gpu/common:util",
|
||||
@ -84,6 +85,7 @@ cc_library(
|
||||
":gpu_object",
|
||||
":opencl_wrapper",
|
||||
":util",
|
||||
"//tensorflow/lite/delegates/gpu/common:data_type",
|
||||
"//tensorflow/lite/delegates/gpu/common:status",
|
||||
"@com_google_absl//absl/strings",
|
||||
"@com_google_absl//absl/types:span",
|
||||
@ -330,6 +332,7 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "gpu_object",
|
||||
srcs = ["gpu_object.cc"],
|
||||
hdrs = ["gpu_object.h"],
|
||||
deps = [
|
||||
":opencl_wrapper",
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "absl/strings/str_split.h"
|
||||
#include "absl/strings/substitute.h"
|
||||
#include "tensorflow/lite/delegates/gpu/cl/tensor_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -457,21 +458,15 @@ std::string Arguments::GetListOfArgs() {
|
||||
for (auto& t : buffers_) {
|
||||
const std::string type_name =
|
||||
t.second.data_type == DataType::FLOAT32 ? "float" : "half";
|
||||
std::string memory_type;
|
||||
switch (t.second.memory_type) {
|
||||
case MemoryType::GLOBAL:
|
||||
memory_type = "__global";
|
||||
break;
|
||||
case MemoryType::CONSTANT:
|
||||
memory_type = "__constant";
|
||||
break;
|
||||
case MemoryType::LOCAL:
|
||||
memory_type = "__local";
|
||||
break;
|
||||
std::string attributes;
|
||||
for (const auto& attr : t.second.attributes) {
|
||||
attributes += absl::StrCat(" __attribute__((", attr, "))");
|
||||
}
|
||||
AppendArgument(absl::StrCat(memory_type, " ", type_name,
|
||||
t.second.element_size, "* ", t.first),
|
||||
&result);
|
||||
AppendArgument(
|
||||
absl::StrCat(MemoryTypeToCLType(t.second.memory_type), " ",
|
||||
ToCLDataType(t.second.data_type, t.second.element_size),
|
||||
"* ", t.first, attributes),
|
||||
&result);
|
||||
}
|
||||
for (auto& t : image_buffers_) {
|
||||
AppendArgument(absl::StrCat(GetImageModifier(t.second.access_type),
|
||||
|
@ -15,6 +15,9 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/buffer.h"
|
||||
|
||||
#include <string>
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/common/data_type.h"
|
||||
#include "tensorflow/lite/delegates/gpu/common/status.h"
|
||||
|
||||
namespace tflite {
|
||||
@ -51,6 +54,7 @@ GPUResources BufferDescriptor::GetGPUResources(AccessType access_type) const {
|
||||
desc.access_type = access_type;
|
||||
desc.element_size = element_size;
|
||||
desc.memory_type = memory_type;
|
||||
desc.attributes = attributes;
|
||||
resources.buffers.push_back({"buffer", desc});
|
||||
return resources;
|
||||
}
|
||||
@ -61,7 +65,7 @@ absl::Status BufferDescriptor::PerformSelector(
|
||||
if (selector == "Read") {
|
||||
return PerformReadSelector(args, result);
|
||||
} else if (selector == "GetPtr") {
|
||||
return PerformGetPtrSelector(args, result);
|
||||
return PerformGetPtrSelector(args, template_args, result);
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"BufferDescriptor don't have selector with name - ", selector));
|
||||
@ -80,13 +84,34 @@ absl::Status BufferDescriptor::PerformReadSelector(
|
||||
}
|
||||
|
||||
absl::Status BufferDescriptor::PerformGetPtrSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
if (!args.empty()) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("BufferDescriptor GetPtr require zero arguments, but ",
|
||||
args.size(), " was passed"));
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const {
|
||||
if (args.size() > 1) {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"BufferDescriptor GetPtr require one or zero arguments, but ",
|
||||
args.size(), " was passed"));
|
||||
}
|
||||
if (template_args.size() > 1) {
|
||||
return absl::NotFoundError(
|
||||
absl::StrCat("BufferDescriptor GetPtr require one or zero teemplate "
|
||||
"arguments, but ",
|
||||
template_args.size(), " was passed"));
|
||||
}
|
||||
std::string conversion;
|
||||
if (template_args.size() == 1) {
|
||||
const std::string type_name = ToCLDataType(element_type, element_size);
|
||||
if (type_name != template_args[0]) {
|
||||
conversion = absl::StrCat("(", MemoryTypeToCLType(memory_type), " ",
|
||||
template_args[0], "*)&");
|
||||
}
|
||||
}
|
||||
if (args.empty()) {
|
||||
*result = absl::StrCat(conversion, "buffer");
|
||||
} else if (conversion.empty()) {
|
||||
*result = absl::StrCat("(buffer + ", args[0], ")");
|
||||
} else {
|
||||
*result = absl::StrCat(conversion, "buffer[", args[0], "]");
|
||||
}
|
||||
*result = "buffer";
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -30,9 +30,10 @@ namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
struct BufferDescriptor : public GPUObjectDescriptor {
|
||||
DataType element_type; // FLOAT32 or FLOAT16
|
||||
DataType element_type;
|
||||
int element_size;
|
||||
MemoryType memory_type = MemoryType::GLOBAL;
|
||||
std::vector<std::string> attributes;
|
||||
|
||||
absl::Status PerformSelector(const std::string& selector,
|
||||
const std::vector<std::string>& args,
|
||||
@ -42,8 +43,9 @@ struct BufferDescriptor : public GPUObjectDescriptor {
|
||||
GPUResources GetGPUResources(AccessType access_type) const override;
|
||||
absl::Status PerformReadSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
absl::Status PerformGetPtrSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
absl::Status PerformGetPtrSelector(
|
||||
const std::vector<std::string>& args,
|
||||
const std::vector<std::string>& template_args, std::string* result) const;
|
||||
};
|
||||
|
||||
// Buffer represent linear GPU data storage with arbitrary data format.
|
||||
|
tensorflow/lite/delegates/gpu/cl/gpu_object.cc (new file, 37 lines)
@ -0,0 +1,37 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/gpu/cl/gpu_object.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace gpu {
|
||||
namespace cl {
|
||||
|
||||
std::string MemoryTypeToCLType(MemoryType type) {
|
||||
switch (type) {
|
||||
case MemoryType::GLOBAL:
|
||||
return "__global";
|
||||
case MemoryType::CONSTANT:
|
||||
return "__constant";
|
||||
case MemoryType::LOCAL:
|
||||
return "__local";
|
||||
}
|
||||
return "";
|
||||
}
|
||||
|
||||
} // namespace cl
|
||||
} // namespace gpu
|
||||
} // namespace tflite
|
@ -56,11 +56,14 @@ struct GPUImageBufferDescriptor {
|
||||
|
||||
enum class MemoryType { GLOBAL, CONSTANT, LOCAL };
|
||||
|
||||
std::string MemoryTypeToCLType(MemoryType type);
|
||||
|
||||
struct GPUBufferDescriptor {
|
||||
DataType data_type;
|
||||
AccessType access_type;
|
||||
int element_size;
|
||||
MemoryType memory_type = MemoryType::GLOBAL;
|
||||
std::vector<std::string> attributes;
|
||||
cl_mem memory;
|
||||
};
|
||||
|
||||
|
@ -28,21 +28,17 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GenerateConvolutionTransposedCode(
|
||||
const OperationDef& op_def, int src_depth, int dst_channels,
|
||||
const int2& kernel_size, const CLDevice& device,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data",
|
||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data",
|
||||
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GenerateConvolutionTransposedCode(const OperationDef& op_def,
|
||||
int src_depth, int dst_channels,
|
||||
const int2& kernel_size,
|
||||
Arguments* args) {
|
||||
args->AddObjectRef(
|
||||
"src_tensor", AccessType::READ,
|
||||
absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
const std::string channel_x = dst_channels == 1 ? "" : ".x";
|
||||
const std::vector<std::string> postfix = {channel_x, ".y", ".z", ".w"};
|
||||
const std::vector<std::string> channel = {".x", ".y", ".z", ".w"};
|
||||
@ -62,36 +58,33 @@ std::string GenerateConvolutionTransposedCode(
|
||||
break;
|
||||
}
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
||||
c += " __constant FLT4* filters";
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size, \n";
|
||||
c += " FLT4 bias_value \n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
c += " int X = linear_id / dst_size.w;\n";
|
||||
c += " int B = linear_id % dst_size.w;\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
c += " args.dst_tensor.SetBatchRef(B);\n";
|
||||
c += " args.src_tensor.SetBatchRef(B);\n";
|
||||
} else {
|
||||
c += " int X = get_global_id(0);\n";
|
||||
}
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " if (X >= src_size.x || Y >= src_size.y) return;\n";
|
||||
c += " if (X >= args.src_tensor.Width() || Y >= args.src_tensor.Height()) "
|
||||
"return;\n";
|
||||
c += " " + accum_type + " r[" + std::to_string(kernel_size.y) + "][" +
|
||||
std::to_string(kernel_size.x) + "];\n";
|
||||
c += " {\n";
|
||||
c += " FLT4 src = " + src_tensor.ReadWHSB("X", "Y", "0", batch_id) + ";\n";
|
||||
c += " FLT4 src = args.src_tensor.Read(X, Y, 0);\n";
|
||||
int index = 0;
|
||||
for (int y = 0; y < kernel_size.y; ++y) {
|
||||
for (int x = 0; x < kernel_size.x; ++x) {
|
||||
std::string r_s =
|
||||
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
|
||||
for (int d = 0; d < dst_channels; ++d) {
|
||||
c += r_s + postfix[d] + " = dot(src, filters[" + std::to_string(index) +
|
||||
"]);\n";
|
||||
c += r_s + postfix[d] + " = dot(src, args.weights.Read(" +
|
||||
std::to_string(index) + "));\n";
|
||||
index++;
|
||||
}
|
||||
}
|
||||
@ -100,15 +93,15 @@ std::string GenerateConvolutionTransposedCode(
|
||||
for (int i = 1; i < src_depth; ++i) {
|
||||
c += " if (X > " + std::to_string(-i) +
|
||||
") { // always true, to reduce registers usage\n";
|
||||
c += " FLT4 src = " +
|
||||
src_tensor.ReadWHSB("X", "Y", std::to_string(i), batch_id) + ";\n";
|
||||
c +=
|
||||
" FLT4 src = args.src_tensor.Read(X, Y, " + std::to_string(i) + ");\n";
|
||||
for (int y = 0; y < kernel_size.y; ++y) {
|
||||
for (int x = 0; x < kernel_size.x; ++x) {
|
||||
std::string r_s =
|
||||
" r[" + std::to_string(y) + "][" + std::to_string(x) + "]";
|
||||
for (int d = 0; d < dst_channels; ++d) {
|
||||
c += r_s + postfix[d] + " += dot(src, filters[" +
|
||||
std::to_string(index) + "]);\n";
|
||||
c += r_s + postfix[d] + " += dot(src, args.weights.Read(" +
|
||||
std::to_string(index) + "));\n";
|
||||
index++;
|
||||
}
|
||||
}
|
||||
@ -121,21 +114,16 @@ std::string GenerateConvolutionTransposedCode(
|
||||
for (int x = 0; x < kernel_size.x; ++x) {
|
||||
const std::string x_coord = "X + " + std::to_string(x);
|
||||
const std::string y_coord = "Y + " + std::to_string(y);
|
||||
c += " if (" + x_coord + " < dst_size.x && " + y_coord +
|
||||
" < dst_size.y) {\n";
|
||||
c += " FLT4 result = bias_value;\n";
|
||||
c += " if (" + x_coord + " < args.dst_tensor.Width() && " + y_coord +
|
||||
" < args.dst_tensor.Height()) {\n";
|
||||
c += " FLT4 result = args.weights.Read(" + std::to_string(index) +
|
||||
");\n";
|
||||
for (int d = 0; d < dst_channels; ++d) {
|
||||
c += " result" + channel[d] + " += r[" + std::to_string(y) + "][" +
|
||||
std::to_string(x) + "]" + postfix[d] + ";\n";
|
||||
}
|
||||
  const std::string x_3dcoord = op_def.IsBatchSupported()
                                    ? "(" + x_coord + ") * dst_size.w + B"
                                    : x_coord;
  const LinkingContext context{"result", x_3dcoord, y_coord, "0"};
  c += PostProcess(linked_operations, context);
  c += " " +
       dst_tensor.WriteWHSB("result", x_coord, y_coord, "0", batch_id) +
       "\n";
  c += " args.dst_tensor.Write(result, " + x_coord + ", " + y_coord +
       ", 0);\n";
  c += " }\n";
  }
}
@ -150,19 +138,11 @@ ConvolutionTransposedThin::ConvolutionTransposedThin(
    : GPUOperation(definition),
      kernel_size_(attr.weights.shape.w, attr.weights.shape.h),
      src_channels_(attr.weights.shape.i),
      dst_channels_(attr.weights.shape.o) {
  float4 bias_value(0.0f);
  for (int i = 0; i < attr.weights.shape.o; ++i) {
    bias_value[i] = attr.bias.data[i];
  }
  bias_value_ = FLT4(definition_.precision, bias_value);
}
      dst_channels_(attr.weights.shape.o) {}

ConvolutionTransposedThin::ConvolutionTransposedThin(
    ConvolutionTransposedThin&& operation)
    : GPUOperation(std::move(operation)),
      weights_buf_(std::move(operation.weights_buf_)),
      bias_value_(std::move(operation.bias_value_)),
      kernel_size_(operation.kernel_size_),
      src_channels_(operation.src_channels_),
      dst_channels_(operation.dst_channels_),
@ -172,8 +152,6 @@ ConvolutionTransposedThin::ConvolutionTransposedThin(
ConvolutionTransposedThin& ConvolutionTransposedThin::operator=(
    ConvolutionTransposedThin&& operation) {
  if (this != &operation) {
    weights_buf_ = std::move(operation.weights_buf_);
    bias_value_ = std::move(operation.bias_value_);
    std::swap(kernel_size_, operation.kernel_size_);
    std::swap(src_channels_, operation.src_channels_);
    std::swap(dst_channels_, operation.dst_channels_);
@ -186,9 +164,15 @@ ConvolutionTransposedThin& ConvolutionTransposedThin::operator=(

absl::Status ConvolutionTransposedThin::Compile(
    const CreationContext& creation_context) {
  const auto code = GenerateConvolutionTransposedCode(
  std::string code = GenerateConvolutionTransposedCode(
      definition_, DivideRoundUp(src_channels_, 4), dst_channels_, kernel_size_,
      *creation_context.device, linked_operations_);
      &args_);
  std::string element_wise_code;
  RETURN_IF_ERROR(
      MergeOperations(linked_operations_, &args_, &element_wise_code));
  RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
                                          {{"dst_tensor", element_wise_code}},
                                          &code));

  std::vector<CompilerOptions> options;
  if (definition_.precision == CalculationsPrecision::F16 &&
@ -202,15 +186,10 @@ absl::Status ConvolutionTransposedThin::Compile(
}

absl::Status ConvolutionTransposedThin::BindArguments() {
  kernel_.ResetBindingCounter();
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_buf_.GetMemoryPtr()));
  RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
  RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
  RETURN_IF_ERROR(kernel_.SetBytesAuto(bias_value_));
  return absl::OkStatus();
  RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
  RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
  RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
  return args_.Bind(kernel_.kernel());
}

int3 ConvolutionTransposedThin::GetGridSize() const {
@ -248,7 +227,7 @@ absl::Status CreateConvolutionTransposedThin(
  }
  *result = ConvolutionTransposedThin(definition, attr);
  RETURN_IF_ERROR(
      result->UploadWeights(attr.weights, creation_context.context));
      result->UploadData(attr.weights, attr.bias, creation_context.context));
  return absl::OkStatus();
}

@ -58,8 +58,9 @@ class ConvolutionTransposedThin : public GPUOperation {
  ConvolutionTransposedThin(const OperationDef& definition,
                            const ConvolutionTransposedAttributes& attr);
  template <DataType T>
  absl::Status UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
                             CLContext* context);
  absl::Status UploadData(const tflite::gpu::Tensor<OHWI, T>& weights,
                          const tflite::gpu::Tensor<Linear, T>& biases,
                          CLContext* context);

  template <DataType S, typename T>
  void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
@ -68,9 +69,6 @@ class ConvolutionTransposedThin : public GPUOperation {
  absl::Status BindArguments();
  int3 GetGridSize() const;

  Buffer weights_buf_;
  FLT4 bias_value_;

  int2 kernel_size_;
  int src_channels_;
  int dst_channels_;
@ -80,25 +78,50 @@ class ConvolutionTransposedThin : public GPUOperation {
};

template <DataType T>
absl::Status ConvolutionTransposedThin::UploadWeights(
    const tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
absl::Status ConvolutionTransposedThin::UploadData(
    const tflite::gpu::Tensor<OHWI, T>& weights,
    const tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
  const int src_depth = DivideRoundUp(src_channels_, 4);
  const int elements_count =
      kernel_size_.x * kernel_size_.y * src_depth * 4 * dst_channels_;
  const int flt4_count =
      kernel_size_.x * kernel_size_.y * src_depth * dst_channels_;

  const int float4_size =
      definition_.precision == CalculationsPrecision::F32 ? 16 : 8;
  if (definition_.GetDataType() == DataType::FLOAT32) {
    std::vector<float4> gpu_data(elements_count);
  const bool f32_weights = definition_.precision == CalculationsPrecision::F32;

  BufferDescriptor desc;
  desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
  desc.element_size = 4;
  desc.memory_type = MemoryType::CONSTANT;

  Buffer weights_buffer;
  if (f32_weights) {
    std::vector<float4> gpu_data(flt4_count);
    RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
    return CreateReadOnlyBuffer(float4_size * elements_count, gpu_data.data(),
                                context, &weights_buf_);
    float4 bias_value(0.0f);
    for (int i = 0; i < weights.shape.o; ++i) {
      bias_value[i] = biases.data[i];
    }
    gpu_data.push_back(bias_value);
    RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(float4) * gpu_data.size(),
                                         gpu_data.data(), context,
                                         &weights_buffer));
  } else {
    std::vector<half4> gpu_data(elements_count);
    std::vector<half4> gpu_data(flt4_count);
    RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
    return CreateReadOnlyBuffer(float4_size * elements_count, gpu_data.data(),
                                context, &weights_buf_);
    half4 bias_value(0.0f);
    for (int i = 0; i < weights.shape.o; ++i) {
      bias_value[i] = biases.data[i];
    }
    gpu_data.push_back(bias_value);
    RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(half4) * gpu_data.size(),
                                         gpu_data.data(), context,
                                         &weights_buffer));
  }

  args_.AddObject("weights", AccessType::READ,
                  absl::make_unique<Buffer>(std::move(weights_buffer)),
                  absl::make_unique<BufferDescriptor>(desc));

  return absl::OkStatus();
}

template <DataType S, typename T>

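For orientation, the UploadData template above packs the rearranged weights and then one trailing FLT4 of biases into a single constant buffer. A minimal, self-contained sketch of the resulting size computation (the helper below is illustrative only and not part of the TFLite GPU API):

#include <cstddef>

// Illustrative sketch: mirrors flt4_count plus one trailing bias element.
constexpr int DivideRoundUpInt(int n, int d) { return (n + d - 1) / d; }

constexpr std::size_t ConstantBufferBytes(int kernel_w, int kernel_h,
                                          int src_channels, int dst_channels,
                                          bool f32_weights) {
  const int src_depth = DivideRoundUpInt(src_channels, 4);
  const int flt4_count = kernel_w * kernel_h * src_depth * dst_channels;
  // float4 is 16 bytes, half4 is 8 bytes.
  const std::size_t bytes_per_flt4 = f32_weights ? 16 : 8;
  return static_cast<std::size_t>(flt4_count + 1) * bytes_per_flt4;
}
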
@ -147,6 +147,13 @@ absl::Status TensorDescriptor::PerformSelector(
  } else if (selector == "Slices") {
    *result = "slices";
    return absl::OkStatus();
  } else if (selector == "SliceStride") {
    if (IsBatchedWidth()) {
      *result = "width_batched * height";
    } else {
      *result = "width * height";
    }
    return absl::OkStatus();
  } else if (selector == "Channels") {
    *result = "channels";
    return absl::OkStatus();

@ -55,9 +55,12 @@ inline void GetActivationMinMax(FusedActivationFunctionType ac,
  }
}

inline float ActivationFunctionWithMinMax(float x, float output_activation_min,
                                          float output_activation_max) {
  return std::min(std::max(x, output_activation_min), output_activation_max);
template <typename T>
inline T ActivationFunctionWithMinMax(T x, T output_activation_min,
                                      T output_activation_max) {
  using std::max;
  using std::min;
  return min(max(x, output_activation_min), output_activation_max);
}

// Legacy function, left for compatibility only.

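Making the clamp a template is what lets the new int64 Sub paths further down reuse it unchanged; pulling in std::min/std::max with using-declarations and calling them unqualified presumably also keeps the helper usable with types that provide their own min/max overloads. A small worked illustration (values made up, not taken from the tests):

// With RELU_N1_TO_1-style bounds of [-1, 1]:
//   ActivationFunctionWithMinMax<int64_t>(-21, -1, 1) evaluates to -1 (clamped up),
//   ActivationFunctionWithMinMax<int64_t>(3, -1, 1)   evaluates to  1 (clamped down),
//   and in-range values pass through unchanged.
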
@ -2766,37 +2766,39 @@ inline void SubNonBroadcast(const ArithmeticParams& params,
  }
}

inline void SubWithActivation(const ArithmeticParams& params,
                              const RuntimeShape& input1_shape,
                              const int32* input1_data,
                              const RuntimeShape& input2_shape,
                              const int32* input2_data,
                              const RuntimeShape& output_shape,
                              int32* output_data) {
  ruy::profiler::ScopeLabel label("SubWithActivation/int32");
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = ActivationFunctionWithMinMax(
        input1_data[i] - input2_data[i], params.quantized_activation_min,
        params.quantized_activation_max);
  }
inline void SetActivationMinMax(const ArithmeticParams& params,
                                int32* activation_min, int32* activation_max) {
  *activation_min = params.quantized_activation_min;
  *activation_max = params.quantized_activation_max;
}

inline void SubWithActivation(const ArithmeticParams& params,
                              const RuntimeShape& input1_shape,
                              const float* input1_data,
                              const RuntimeShape& input2_shape,
                              const float* input2_data,
                              const RuntimeShape& output_shape,
                              float* output_data) {
  ruy::profiler::ScopeLabel label("SubWithActivation/float");
inline void SetActivationMinMax(const ArithmeticParams& params,
                                float* activation_min, float* activation_max) {
  *activation_min = params.float_activation_min;
  *activation_max = params.float_activation_max;
}

inline void SetActivationMinMax(const ArithmeticParams& params,
                                int64_t* activation_min,
                                int64_t* activation_max) {
  *activation_min = params.int64_activation_min;
  *activation_max = params.int64_activation_max;
}

template <typename T>
inline void SubWithActivation(
    const ArithmeticParams& params, const RuntimeShape& input1_shape,
    const T* input1_data, const RuntimeShape& input2_shape,
    const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("SubWithActivation_optimized");
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  T activation_min, activation_max;
  SetActivationMinMax(params, &activation_min, &activation_max);

  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = ActivationFunctionWithMinMax(
        input1_data[i] - input2_data[i], params.float_activation_min,
        params.float_activation_max);
        input1_data[i] - input2_data[i], activation_min, activation_max);
  }
}

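Because SetActivationMinMax is overloaded on the pointer type, the templated SubWithActivation above picks the matching ArithmeticParams fields purely from T, with no runtime branching on the dtype. A standalone sketch of the same dispatch idea (generic C++, not the TFLite declarations):

#include <cstdint>
#include <iostream>

struct Params {
  std::int32_t quantized_min = -128, quantized_max = 127;
  std::int64_t int64_min = -1, int64_max = 1;
};

// One overload per activation-bound type, as in the hunk above.
void SetMinMax(const Params& p, std::int32_t* lo, std::int32_t* hi) {
  *lo = p.quantized_min;
  *hi = p.quantized_max;
}
void SetMinMax(const Params& p, std::int64_t* lo, std::int64_t* hi) {
  *lo = p.int64_min;
  *hi = p.int64_max;
}

template <typename T>
void Clamp(const Params& p, T* x) {
  T lo, hi;
  SetMinMax(p, &lo, &hi);  // overload chosen from T at compile time
  if (*x < lo) *x = lo;
  if (*x > hi) *x = hi;
}

int main() {
  Params p;
  std::int64_t v = -21;
  Clamp(p, &v);            // selects the int64_t overload
  std::cout << v << "\n";  // prints -1
}
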
@ -1495,6 +1495,7 @@ inline void GatherNd(const RuntimeShape& params_shape,
  }
}

#ifndef TF_LITE_STATIC_MEMORY
template <typename IndicesT = int32>
inline void GatherNdString(const RuntimeShape& params_shape,
                           const TfLiteTensor* params_data,
@ -1517,6 +1518,7 @@ inline void GatherNdString(const RuntimeShape& params_shape,
  }
  buffer.WriteToTensor(output_data, /*new_shape=*/nullptr);
}
#endif

template <typename IndicesT, typename UpdatesT>
inline void ScatterNd(const RuntimeShape& indices_shape,

@ -260,6 +260,45 @@ inline void BroadcastSubSlow(const ArithmeticParams& params,
  NDOpsHelper<N>(output_desc, sub_func);
}

template <int N = 5>
void BroadcastSubSlow(const ArithmeticParams& params,
                      const RuntimeShape& input1_shape,
                      const int64_t* input1_data,
                      const RuntimeShape& input2_shape,
                      const int64_t* input2_data,
                      const RuntimeShape& output_shape, int64_t* output_data) {
  ruy::profiler::ScopeLabel label("BroadcastSubSlow/int64");
  TFLITE_DCHECK_LE(input1_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(input2_shape.DimensionsCount(), N);
  TFLITE_DCHECK_LE(output_shape.DimensionsCount(), N);
  NdArrayDesc<N> desc1;
  NdArrayDesc<N> desc2;
  NdArrayDesc<N> output_desc;
  NdArrayDescsForElementwiseBroadcast(input1_shape, input2_shape, &desc1,
                                      &desc2);
  CopyDimsToDesc(RuntimeShape::ExtendedShape(N, output_shape), &output_desc);

  // In Tensorflow, the dimensions are canonically named (batch_number, row,
  // col, channel), with extents (batches, height, width, depth), with the
  // trailing dimension changing most rapidly (channels has the smallest stride,
  // typically 1 element).
  //
  // In generated C code, we store arrays with the dimensions reversed. The
  // first dimension has smallest stride.
  //
  // We name our variables by their Tensorflow convention, but generate C code
  // nesting loops such that the innermost loop has the smallest stride for the
  // best cache behavior.
  auto sub_func = [&](int indexes[N]) {
    output_data[SubscriptToIndex(output_desc, indexes)] =
        ActivationFunctionWithMinMax(
            input1_data[SubscriptToIndex(desc1, indexes)] -
                input2_data[SubscriptToIndex(desc2, indexes)],
            params.int64_activation_min, params.int64_activation_max);
  };
  NDOpsHelper<N>(output_desc, sub_func);
}

template <typename T, int N = 5>
void BroadcastSubSlow(const ArithmeticParams& params,
                      const RuntimeShape& input1_shape, const T* input1_data,
@ -434,40 +473,42 @@ void Sub(const ArithmeticParams& params, const RuntimeShape& input1_shape,
  }
}

inline void SubWithActivation(const ArithmeticParams& params,
                              const RuntimeShape& input1_shape,
                              const int32* input1_data,
                              const RuntimeShape& input2_shape,
                              const int32* input2_data,
                              const RuntimeShape& output_shape,
                              int32* output_data) {
inline void SetActivationMinMax(const ArithmeticParams& params,
                                int32* activation_min, int32* activation_max) {
  *activation_min = params.quantized_activation_min;
  *activation_max = params.quantized_activation_max;
}

inline void SetActivationMinMax(const ArithmeticParams& params,
                                float* activation_min, float* activation_max) {
  *activation_min = params.float_activation_min;
  *activation_max = params.float_activation_max;
}

inline void SetActivationMinMax(const ArithmeticParams& params,
                                int64_t* activation_min,
                                int64_t* activation_max) {
  *activation_min = params.int64_activation_min;
  *activation_max = params.int64_activation_max;
}

template <typename T>
inline void SubWithActivation(
    const ArithmeticParams& params, const RuntimeShape& input1_shape,
    const T* input1_data, const RuntimeShape& input2_shape,
    const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
  ruy::profiler::ScopeLabel label("SubWithActivation");
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  T activation_min, activation_max;
  SetActivationMinMax(params, &activation_min, &activation_max);

  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = ActivationFunctionWithMinMax(
        input1_data[i] - input2_data[i], params.quantized_activation_min,
        params.quantized_activation_max);
        input1_data[i] - input2_data[i], activation_min, activation_max);
  }
}

inline void SubWithActivation(const ArithmeticParams& params,
                              const RuntimeShape& input1_shape,
                              const float* input1_data,
                              const RuntimeShape& input2_shape,
                              const float* input2_data,
                              const RuntimeShape& output_shape,
                              float* output_data) {
  const int flat_size =
      MatchingElementsSize(input1_shape, input2_shape, output_shape);
  for (int i = 0; i < flat_size; ++i) {
    output_data[i] = ActivationFunctionWithMinMax(
        input1_data[i] - input2_data[i], params.float_activation_min,
        params.float_activation_max);
  }
}

}  // namespace reference_ops
}  // namespace tflite

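BroadcastSubSlow above walks the broadcast index space through NdArrayDesc descriptors instead of materializing expanded inputs; a dimension of extent 1 effectively gets stride 0. A toy 2-D illustration of that idea in plain C++ (not the TFLite helpers):

#include <cstddef>
#include <cstdint>
#include <vector>

// Subtract a broadcast row vector b[1 x W] from a matrix a[H x W].
std::vector<std::int64_t> BroadcastSubRow(const std::vector<std::int64_t>& a,
                                          int H, int W,
                                          const std::vector<std::int64_t>& b) {
  std::vector<std::int64_t> out(static_cast<std::size_t>(H) * W);
  const int a_stride_h = W, a_stride_w = 1;
  const int b_stride_h = 0, b_stride_w = 1;  // extent-1 dimension: stride 0
  for (int h = 0; h < H; ++h) {
    for (int w = 0; w < W; ++w) {
      out[static_cast<std::size_t>(h) * W + w] =
          a[h * a_stride_h + w * a_stride_w] -
          b[h * b_stride_h + w * b_stride_w];
    }
  }
  return out;
}
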
@ -765,12 +765,17 @@ struct ArithmeticParams {
  int input1_shift;
  int32 input2_multiplier;
  int input2_shift;

  // TODO(b/158622529): Union the following activation params.
  // uint8, etc, activation params.
  int32 quantized_activation_min;
  int32 quantized_activation_max;
  // float activation params.
  float float_activation_min;
  float float_activation_max;
  // int64 activation params.
  int64_t int64_activation_min;
  int64_t int64_activation_max;

  // Processed output dimensions.
  // Let input "a" be the one that broadcasts in the faster-changing dimension.
@ -1114,6 +1119,12 @@ inline void SetActivationParams(int32 min, int32 max, P* params) {
  params->quantized_activation_max = max;
}

template <typename P>
inline void SetActivationParams(int64_t min, int64_t max, P* params) {
  params->int64_activation_min = min;
  params->int64_activation_max = max;
}

template <typename P>
inline void GetActivationParams(const P& params, int32* min, int32* max) {
  *min = params.quantized_activation_min;
@ -1126,6 +1137,11 @@ inline void GetActivationParams(const P& params, float* min, float* max) {
  *max = params.float_activation_max;
}

template <typename P>
inline void GetActivationParams(const P& params, int64_t* min, int64_t* max) {
  *min = params.int64_activation_min;
  *max = params.int64_activation_max;
}
}  // namespace tflite

#endif  // TENSORFLOW_LITE_KERNELS_INTERNAL_TYPES_H_

@ -44,6 +44,7 @@ inline TfLiteTensor* GetOutput(TfLiteContext* context, const TfLiteNode* node,
                               int index) {
  return &context->tensors[node->outputs->data[index]];
}
#ifndef TF_LITE_STATIC_MEMORY
inline TfLiteTensor* GetTemporary(TfLiteContext* context,
                                  const TfLiteNode* node, int index) {
  return &context->tensors[node->temporaries->data[index]];
@ -52,11 +53,12 @@ inline const TfLiteTensor* GetIntermediates(TfLiteContext* context,
                                            const TfLiteNode* node, int index) {
  return &context->tensors[node->intermediates->data[index]];
}
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }
inline int NumIntermediates(const TfLiteNode* node) {
  return node->intermediates->size;
}
#endif  // TF_LITE_STATIC_MEMORY
inline int NumInputs(const TfLiteNode* node) { return node->inputs->size; }
inline int NumOutputs(const TfLiteNode* node) { return node->outputs->size; }

inline int64_t NumElements(const TfLiteIntArray* dims) {
  int64_t count = 1;

@ -142,8 +142,8 @@ BuiltinOpResolver::BuiltinOpResolver() {
             /* min_version */ 1,
             /* max_version */ 2);
  AddBuiltin(BuiltinOperator_SUB, Register_SUB(),
             /* min_version */ 1,
             /* max_version */ 5);
             /* min_version = */ 1,
             /* max_version = */ 5);
  AddBuiltin(BuiltinOperator_SPLIT, Register_SPLIT(),
             /* min_version = */ 1,
             /* max_version = */ 4);

@ -340,6 +340,11 @@ void EvalSub(TfLiteContext* context, TfLiteNode* node, TfLiteSubParams* params,
      EvalSubImpl<kernel_type, float>(context, node, params, data, input1,
                                      input2, requires_broadcast, output);
      break;
    case kTfLiteInt64:
      EvalSubImpl<kernel_type, int64_t>(context, node, params, data, input1,
                                        input2, requires_broadcast, output);
      break;

    default:
      TF_LITE_KERNEL_LOG(context, "output type %s is not supported.",
                         TfLiteTypeGetName(output->type));
@ -434,7 +439,8 @@ TfLiteStatus Eval(TfLiteContext* context, TfLiteNode* node) {
  const TfLiteTensor* input2 = GetInput(context, node, kInputTensor2);
  TfLiteTensor* output = GetOutput(context, node, kOutputTensor);

  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32) {
  if (output->type == kTfLiteFloat32 || output->type == kTfLiteInt32 ||
      output->type == kTfLiteInt64) {
    EvalSub<kernel_type>(context, node, params, data, input1, input2, output);
  } else if (output->type == kTfLiteUInt8 || output->type == kTfLiteInt8 ||
             output->type == kTfLiteInt16) {

@ -63,6 +63,13 @@ class IntegerSubOpModel : public BaseSubOpModel {
  std::vector<int32_t> GetOutput() { return ExtractVector<int32_t>(output_); }
};

class Int64SubOpModel : public BaseSubOpModel {
 public:
  using BaseSubOpModel::BaseSubOpModel;

  std::vector<int64_t> GetOutput() { return ExtractVector<int64_t>(output_); }
};

class QuantizedSubOpModel : public BaseSubOpModel {
 public:
  using BaseSubOpModel::BaseSubOpModel;
@ -213,6 +220,57 @@ TEST(IntegerSubOpModel, WithBroadcast) {
  }
}

TEST(Int64SubOpModel, NoActivation) {
  Int64SubOpModel m({TensorType_INT64, {1, 2, 2, 1}},
                    {TensorType_INT64, {1, 2, 2, 1}}, {TensorType_INT64, {}},
                    ActivationFunctionType_NONE);
  m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8});
  m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3}));
}

TEST(Int64SubOpModel, ActivationRELU_N1_TO_1) {
  Int64SubOpModel m({TensorType_INT64, {1, 2, 2, 1}},
                    {TensorType_INT64, {1, 2, 2, 1}}, {TensorType_INT64, {}},
                    ActivationFunctionType_RELU_N1_TO_1);
  m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8});
  m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5});
  m.Invoke();
  EXPECT_THAT(m.GetOutput(), ElementsAreArray({-1, 0, 1, 1}));
}

TEST(Int64SubOpModel, VariousInputShapes) {
  std::vector<std::vector<int>> test_shapes = {
      {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}};
  for (int i = 0; i < test_shapes.size(); ++i) {
    Int64SubOpModel m({TensorType_INT64, test_shapes[i]},
                      {TensorType_INT64, test_shapes[i]},
                      {TensorType_INT64, {}}, ActivationFunctionType_NONE);
    m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
    m.PopulateTensor<int64_t>(m.input2(), {1, 2, 3, 5, 11, 1});
    m.Invoke();
    EXPECT_THAT(m.GetOutput(), ElementsAreArray({-21, 0, 4, 3, 0, 19}))
        << "With shape number " << i;
  }
}

TEST(Int64SubOpModel, WithBroadcast) {
  std::vector<std::vector<int>> test_shapes = {
      {6}, {2, 3}, {2, 1, 3}, {1, 3, 1, 2}, {1, 3, 1, 2, 1}};
  for (int i = 0; i < test_shapes.size(); ++i) {
    Int64SubOpModel m({TensorType_INT64, test_shapes[i]},
                      {TensorType_INT64, {}},  // always a scalar
                      {TensorType_INT64, {}}, ActivationFunctionType_NONE);
    m.PopulateTensor<int64_t>(m.input1(), {-20, 2, 7, 8, 11, 20});
    m.PopulateTensor<int64_t>(m.input2(), {1});
    m.Invoke();
    EXPECT_THAT(m.GetOutput(),
                ElementsAreArray(ArrayFloatNear({-21, 1, 6, 7, 10, 19})))
        << "With shape number " << i;
  }
}

template <TensorType tensor_type, typename integer_dtype>
void QuantizedTestsNoActivation() {
  float kQuantizedTolerance = GetTolerance(-1.0, 1.0);

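The ActivationRELU_N1_TO_1 test above doubles as a worked example of the clamping path: the raw differences are -20-1 = -21, 2-2 = 0, 7-3 = 4 and 8-5 = 3, and clamping each to the [-1, 1] activation range gives the expected {-1, 0, 1, 1}.
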
@ -53,19 +53,14 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
  // There is 1 output at index 3 in the tensors array.
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  // There are no temporaries.
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(conv_params);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TfLiteStatus prepare_status = registration->prepare(&context, &node);
@ -66,19 +66,14 @@ TfLiteStatus ValidateDepthwiseConvGoldens(TfLiteTensor* tensors,
  // There is 1 output at index 3 in the tensors array.
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  // There are no intermediates.
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TfLiteStatus prepare_status = registration->prepare(&context, &node);

@ -33,20 +33,23 @@ limitations under the License.
namespace {

// Create an area of memory to use for input, output, and intermediate arrays.
constexpr int tensor_arena_size = 73 * 1024;
uint8_t tensor_arena[tensor_arena_size];
// Align arena to 16 bytes to avoid alignment warnings on certain platforms.
constexpr int tensor_arena_size = 21 * 1024;
alignas(16) uint8_t tensor_arena[tensor_arena_size];
// A random number generator seed to generate input values.
constexpr int kRandomSeed = 42;

// NOLINTNEXTLINE
MicroBenchmarkRunner<int16_t> runner(g_keyword_scrambled_model_data,
                                     tensor_arena, tensor_arena_size,
                                     kRandomSeed);
MicroBenchmarkRunner<int16_t>& GetBenchmarkRunner() {
  // NOLINTNEXTLINE
  static MicroBenchmarkRunner<int16_t> runner(
      g_keyword_scrambled_model_data, tensor_arena, tensor_arena_size, 0);
  return runner;
}

void KeywordRunTenIerations() {
  // TODO(b/152644476): Add a way to run more than a single deterministic input.
  for (int i = 0; i < 10; i++) {
    runner.RunSingleIterationRandomInput();
    GetBenchmarkRunner().RunSingleIterationRandomInput();
  }
}

@ -54,7 +57,7 @@ void KeywordRunTenIerations() {

TF_LITE_MICRO_BENCHMARKS_BEGIN

TF_LITE_MICRO_BENCHMARK(runner.RunSingleIterationRandomInput());
TF_LITE_MICRO_BENCHMARK(GetBenchmarkRunner().RunSingleIterationRandomInput());

TF_LITE_MICRO_BENCHMARK(KeywordRunTenIerations());

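Wrapping the runner in a function-local static, as the keyword benchmark above now does, defers construction of the non-trivial MicroBenchmarkRunner until first use instead of running it during static initialization. A generic sketch of the same accessor pattern (names below are illustrative, not from the benchmark):

#include <cstddef>
#include <cstdint>

struct Runner {
  Runner(std::uint8_t* arena, std::size_t arena_size)
      : arena_(arena), arena_size_(arena_size) {}
  void RunOnce() { /* exercise the model once */ }
  std::uint8_t* arena_;
  std::size_t arena_size_;
};

alignas(16) static std::uint8_t g_arena[21 * 1024];

Runner& GetRunner() {
  // Constructed lazily on the first call; subsequent calls reuse it.
  static Runner runner(g_arena, sizeof(g_arena));
  return runner;
}
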
@ -27,10 +27,6 @@ limitations under the License.
#include "tensorflow/lite/schema/schema_generated.h"
#include "tensorflow/lite/version.h"

// Create an area of memory to use for input, output, and intermediate arrays.
constexpr int tensor_arena_size = 73 * 1024;
uint8_t tensor_arena[tensor_arena_size];

/*
 * Person Detection benchmark. Evaluates runtime performance of the visual
 * wakewords person detection model. This is the same model found in
@ -40,24 +36,28 @@ uint8_t tensor_arena[tensor_arena_size];
namespace {

// Create an area of memory to use for input, output, and intermediate arrays.
// Align arena to 16 bytes to avoid alignment warnings on certain platforms.
constexpr int tensor_arena_size = 95 * 1024;
uint8_t tensor_arena[tensor_arena_size];
alignas(16) uint8_t tensor_arena[tensor_arena_size];

// NOLINTNEXTLINE
MicroBenchmarkRunner<uint8_t> runner(g_person_detect_model_data, tensor_arena,
                                     tensor_arena_size, 0);
MicroBenchmarkRunner<uint8_t>& GetBenchmarkRunner() {
  // NOLINTNEXTLINE
  static MicroBenchmarkRunner<uint8_t> runner(
      g_person_detect_model_data, tensor_arena, tensor_arena_size, 0);
  return runner;
}

void PersonDetectionTenIerationsWithPerson() {
  // TODO(b/152644476): Add a way to run more than a single deterministic input.
  for (int i = 0; i < 10; i++) {
    runner.RunSingleIterationCustomInput(g_person_data);
    GetBenchmarkRunner().RunSingleIterationCustomInput(g_person_data);
  }
}

void PersonDetectionTenIerationsWithoutPerson() {
  // TODO(b/152644476): Add a way to run more than a single deterministic input.
  for (int i = 0; i < 10; i++) {
    runner.RunSingleIterationCustomInput(g_no_person_data);
    GetBenchmarkRunner().RunSingleIterationCustomInput(g_no_person_data);
  }
}

@ -65,7 +65,8 @@ void PersonDetectionTenIerationsWithoutPerson() {

TF_LITE_MICRO_BENCHMARKS_BEGIN

TF_LITE_MICRO_BENCHMARK(runner.RunSingleIterationCustomInput(g_person_data));
TF_LITE_MICRO_BENCHMARK(
    GetBenchmarkRunner().RunSingleIterationCustomInput(g_person_data));
TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithPerson());
TF_LITE_MICRO_BENCHMARK(PersonDetectionTenIerationsWithoutPerson());

@ -60,12 +60,10 @@ void TestReluFloat(const int* input_dims_data, const float* input_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -116,12 +114,10 @@ void TestRelu6Float(const int* input_dims_data, const float* input_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -177,12 +173,10 @@ void TestReluUint8(const int* input_dims_data, const float* input_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -242,12 +236,10 @@ void TestRelu6Uint8(const int* input_dims_data, const float* input_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -306,12 +298,10 @@ void TestReluInt8(const int* input_dims_data, const float* input_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -372,12 +362,10 @@ void TestRelu6Int8(const int* input_dims_data, const float* input_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }

@ -89,18 +89,14 @@ void ValidateAddGoldens(TfLiteTensor* tensors, int tensors_size,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -46,17 +46,13 @@ void ValidateArgMinMaxGoldens(TfLiteTensor* tensors, int tensors_size,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -46,17 +46,13 @@ void TestCeil(const int* input_dims_data, const float* input_data,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
  for (int i = 0; i < output_dims_count; ++i) {
@ -56,17 +56,12 @@ TfLiteNode PrepareCircularBufferInt8(const int* input_dims_data,
  // There is one output - tensor 1.
  const int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  // There are no intermediates.
  const int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -104,17 +99,12 @@ TfLiteStatus InvokeCircularBufferInt8(const int* input_dims_data,
  // There is one output - tensor 1.
  const int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  // There are no intermediates.
  const int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  node->inputs = inputs_array;
  node->outputs = outputs_array;
  node->temporaries = temporaries_array;
  node->builtin_data = nullptr;
  node->custom_initial_data = nullptr;
  node->custom_initial_data_size = 0;
  node->delegate = nullptr;

  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);

@ -44,18 +44,14 @@ void TestComparison(tflite::BuiltinOperator op, TfLiteTensor* tensors,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  const int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  const int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -57,19 +57,18 @@ void TestConcatenateTwoInputs(std::initializer_list<int> input1_dims_data,
      .activation = kTfLiteActNone  // Only activation supported in this impl
  };

  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

@ -116,19 +115,18 @@ void TestConcatenateQuantizedTwoInputs(
      .activation = kTfLiteActNone  // Only activation supported in this impl
  };

  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

@ -76,18 +76,14 @@ TfLiteStatus ValidateConvGoldens(TfLiteTensor* tensors, int tensors_size,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(conv_params);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node));
@ -73,18 +73,14 @@ TfLiteStatus ValidateDepthwiseConvGoldens(const T* expected_output_data,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_ENSURE_OK(context, registration->prepare(&context, &node));
  }

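A recurring change in these micro-kernel tests is replacing IntArrayFromInitializer with a named int array passed to IntArrayFromInts. Judging from the arrays used above, the convention is that element 0 holds the length and the remaining elements hold tensor indices, and the named array gives the resulting TfLiteIntArray stable backing storage for the lifetime of the node. A sketch of that layout (illustrative, not the helper's definition):

// data[0] is the length; data[1..] are tensor indices into context->tensors.
int inputs_array_data[] = {2, /*first input*/ 0, /*second input*/ 1};
int outputs_array_data[] = {1, /*output*/ 2};
// IntArrayFromInts-style helpers reinterpret this storage as a TfLiteIntArray,
// so the backing arrays must outlive the TfLiteNode that points at them.
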
@ -48,18 +48,14 @@ void ValidateDequantizeGoldens(TfLiteTensor* tensors, int tensors_size,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -54,23 +54,18 @@ void TestElementwiseFloat(tflite::BuiltinOperator op,
  if (registration->init) {
    user_data = registration->init(&context, nullptr, 0);
  }
  auto inputs_array_data = {1, 0};
  TfLiteIntArray* inputs_array = IntArrayFromInitializer(inputs_array_data);
  auto outputs_array_data = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInitializer(outputs_array_data);
  auto temporaries_array_data = {0};
  TfLiteIntArray* temporaries_array =
      IntArrayFromInitializer(temporaries_array_data);
  int inputs_array_data[] = {1, 0};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -119,19 +114,19 @@ void TestElementwiseBool(tflite::BuiltinOperator op,
  if (registration->init) {
    user_data = registration->init(&context, nullptr, 0);
  }
  TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 1});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

  int inputs_array_data[] = {1, 0};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -47,18 +47,13 @@ void TestFloor(const int* input_dims_data, const float* input_data,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int intermediates_array_data[] = {0};
  TfLiteIntArray* temporaries_array =
      IntArrayFromInts(intermediates_array_data);
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
  for (int i = 0; i < output_dims_count; ++i) {

@ -72,18 +72,14 @@ TfLiteStatus TestFullyConnectedFloat(
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node));
  }
@ -151,18 +147,14 @@ TfLiteStatus TestFullyConnectedQuantized(
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_ENSURE_OK(&context, registration->prepare(&context, &node));
@ -112,18 +112,14 @@ void TestL2Normalization(const int* input_dims_data, const T* input_data,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  int temporaries_array_data[] = {0};
  TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));

@ -51,19 +51,18 @@ void TestLogicalOp(tflite::BuiltinOperator op,
  const TfLiteRegistration* registration = resolver.FindOp(op);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

@ -62,12 +62,10 @@ void TestLogisticFloat(std::initializer_list<int> input_dims_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -122,12 +120,10 @@ void TestLogisticInt8(std::initializer_list<int> input_dims_data,
  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = nullptr;
  node.user_data = user_data;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;
  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
  }
@ -52,19 +52,18 @@ void TestMaxMinFloat(tflite::BuiltinOperator op,
  const TfLiteRegistration* registration = resolver.FindOp(op);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -108,19 +107,18 @@ void TestMaxMinQuantized(
  const TfLiteRegistration* registration = resolver.FindOp(op);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -162,19 +160,18 @@ void TestMaxMinQuantizedInt32(
  const TfLiteRegistration* registration = resolver.FindOp(op);
  TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

@ -76,7 +76,6 @@ void TestMulFloat(std::initializer_list<int> input1_dims_data,
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -148,7 +147,6 @@ void TestMulQuantized(std::initializer_list<int> input1_dims_data,
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -50,17 +50,14 @@ void TestNegFloat(std::initializer_list<int> input_dims_data,
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 1};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = nullptr;
  node.builtin_data = nullptr;
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
  TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
@ -64,19 +64,19 @@ void TestPackTwoInputsFloat(std::initializer_list<int> input1_dims_data,
  if (registration->init) {
    user_data = registration->init(&context, nullptr, 0);
  }
  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -141,19 +141,19 @@ void TestPackThreeInputsFloat(std::initializer_list<int> input1_dims_data,
  if (registration->init) {
    user_data = registration->init(&context, nullptr, 0);
  }
  TfLiteIntArray* inputs_array = IntArrayFromInitializer({3, 0, 1, 2});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 3});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

  int inputs_array_data[] = {3, 0, 1, 2};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 3};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -214,19 +214,18 @@ void TestPackTwoInputsQuantized(
  if (registration->init) {
    user_data = registration->init(&context, nullptr, 0);
  }
  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -284,19 +283,18 @@ void TestPackTwoInputsQuantized32(
  if (registration->init) {
    user_data = registration->init(&context, nullptr, 0);
  }
  TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
  TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 2});
  TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
  int inputs_array_data[] = {2, 0, 1};
  TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
  int outputs_array_data[] = {1, 2};
  TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

  TfLiteNode node;
  node.inputs = inputs_array;
  node.outputs = outputs_array;
  node.temporaries = temporaries_array;
  node.user_data = user_data;
  node.builtin_data = reinterpret_cast<void*>(&builtin_data);
  node.custom_initial_data = nullptr;
  node.custom_initial_data_size = 0;
  node.delegate = nullptr;

  if (registration->prepare) {
    TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));

@ -39,17 +39,13 @@ TfLiteStatus ValidatePadGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
TF_LITE_ENSURE_EQ(&context, kTfLiteOk,
registration->prepare(&context, &node));
@ -76,17 +72,13 @@ TfLiteStatus ValidatePadV2Goldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->prepare);
// Prepare should catch dimension mismatches.
TfLiteStatus prepare_status = registration->prepare(&context, &node);
@ -66,18 +66,14 @@ void TestAveragePoolingFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -138,18 +134,14 @@ void TestAveragePoolingQuantized(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -209,18 +201,14 @@ void TestMaxPoolFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -283,18 +271,14 @@ void TestMaxPoolQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -58,16 +58,13 @@ void TestPreluFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -122,16 +119,13 @@ void TestPreluQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -50,18 +50,14 @@ void ValidateQuantizeGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -74,12 +74,10 @@ TfLiteStatus ValidateReduceGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -29,56 +29,29 @@ namespace tflite {
namespace testing {
namespace {

// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
void TestReshapeImpl(TfLiteContext* context, TfLiteNode* node,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
if (shape_tensor == nullptr) {
constexpr int inputs_size = 1;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *output_tensor,
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
node.inputs = IntArrayFromInitializer({1, 0});
node.outputs = IntArrayFromInitializer({1, 1});
} else {
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *shape_tensor;
tensors[2] = *output_tensor;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
node.inputs = IntArrayFromInitializer({2, 0, 1});
node.outputs = IntArrayFromInitializer({1, 2});
}

::tflite::AllOpsResolver resolver;
const TfLiteRegistration* registration =
resolver.FindOp(tflite::BuiltinOperator_RESHAPE);
TF_LITE_MICRO_EXPECT_NE(nullptr, registration);

void* user_data = nullptr;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
node->user_data = user_data;
node->builtin_data = nullptr;
node->custom_initial_data = nullptr;
node->custom_initial_data_size = 0;

TF_LITE_MICRO_EXPECT_EQ(registration->init, nullptr);
TF_LITE_MICRO_EXPECT_EQ(registration->free, nullptr);

if (registration->prepare) {
// Error can happen either in Prepare or eval stage.
auto status = registration->prepare(&context, &node);
auto status = registration->prepare(context, node);
if (status != kTfLiteOk && expect_failure) {
return;
} else {
@ -86,11 +59,10 @@ void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
}
}
if (expect_failure) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
registration->invoke(&context, &node));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError, registration->invoke(context, node));
return;
}
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(context, node));

const int output_dims_count = ElementCount(*output_tensor->dims);
const T* output_data = GetTensorData<T>(output_tensor);
@ -105,6 +77,59 @@ void TestReshapeImpl(TfLiteTensor* input_tensor, TfLiteTensor* shape_tensor,
}
}

// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeWithShapeImpl(TfLiteTensor* input_tensor,
TfLiteTensor* shape_tensor,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
constexpr int inputs_size = 2;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *shape_tensor;
tensors[2] = *output_tensor;
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);

int inputs_data[] = {2, 0, 1};
node.inputs = IntArrayFromInts(inputs_data);
int outputs_data[] = {1, 2};
node.outputs = IntArrayFromInts(outputs_data);

TestReshapeImpl(&context, &node, output_tensor, expected_output,
expected_dims, expect_failure);
}

// If expected output is empty, the test is expected to fail.
template <typename T>
void TestReshapeWithoutShapeImpl(TfLiteTensor* input_tensor,
TfLiteTensor* output_tensor,
std::initializer_list<T> expected_output,
std::initializer_list<int> expected_dims,
bool expect_failure) {
TfLiteContext context;
TfLiteTensor tensors[3];
TfLiteNode node;
constexpr int inputs_size = 1;
constexpr int outputs_size = 1;
constexpr int tensors_size = inputs_size + outputs_size;
tensors[0] = *input_tensor;
tensors[1] = *output_tensor,
PopulateContext(tensors, tensors_size, micro_test::reporter, &context);
int inputs_data[] = {1, 0};
node.inputs = IntArrayFromInts(inputs_data);
int outputs_data[] = {1, 1};
node.outputs = IntArrayFromInts(outputs_data);

TestReshapeImpl(&context, &node, output_tensor, expected_output,
expected_dims, expect_failure);
}

template <typename T = float, TfLiteType tensor_input_type = kTfLiteFloat32>
void TestReshape(std::initializer_list<int> input_dims_data,
std::initializer_list<T> input_data,
@ -122,14 +147,14 @@ void TestReshape(std::initializer_list<int> input_dims_data,
TfLiteTensor output_tensor =
CreateTensor<T, tensor_input_type>(output_data, output_dims);
// Reshape param is passed as op's param.
TestReshapeImpl<T>(&input_tensor, nullptr, &output_tensor, expected_output,
expected_dims, expect_failure);
TestReshapeWithoutShapeImpl<T>(&input_tensor, &output_tensor, expected_output,
expected_dims, expect_failure);
// Reshape param is passed as a tensor.
TfLiteIntArray* shape_dims = IntArrayFromInitializer(shape_dims_data);
auto shape_tensor =
CreateTensor<int32_t, kTfLiteInt32>(shape_data, shape_dims);
TestReshapeImpl<T>(&input_tensor, &shape_tensor, &output_tensor,
expected_output, expected_dims, expect_failure);
TestReshapeWithShapeImpl<T>(&input_tensor, &shape_tensor, &output_tensor,
expected_output, expected_dims, expect_failure);
}
} // namespace
} // namespace testing
@ -192,19 +217,20 @@ TF_LITE_MICRO_TEST(InvalidShape) {
using tflite::testing::CreateFloatTensor;
using tflite::testing::IntArrayFromInitializer;
using tflite::testing::IntArrayFromInts;
TfLiteIntArray* input_dims = IntArrayFromInitializer({3, 1, 2, 2});
int input_dims_data[] = {3, 1, 2, 2};
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
auto input_data = {3.0f};
auto input_tensor = CreateFloatTensor(input_data, input_dims);
float output_data[4];
int output_dims_data[6] = {2, 2, 1, 2, 2, 1};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
auto output_tensor = CreateFloatTensor(output_data, output_dims);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor
nullptr, // shape_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
tflite::testing::TestReshapeWithoutShapeImpl<float>(
&input_tensor, // input_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
);
}

@ -255,29 +281,32 @@ TF_LITE_MICRO_TEST(LegacyScalarOutput) {
using tflite::testing::CreateFloatTensor;
using tflite::testing::IntArrayFromInitializer;
using tflite::testing::IntArrayFromInts;
TfLiteIntArray* input_dims = IntArrayFromInitializer({1, 1});
int input_dims_data[] = {1, 1};
TfLiteIntArray* input_dims = IntArrayFromInts(input_dims_data);
auto input_data = {3.0f};
auto input_tensor = CreateFloatTensor(input_data, input_dims);
float output_data[1];
int output_dims_data[2] = {1, 0};
TfLiteIntArray* output_dims = IntArrayFromInts(output_dims_data);
auto output_tensor = CreateFloatTensor(output_data, output_dims);
TfLiteIntArray* shape_dims = tflite::testing::IntArrayFromInitializer({1, 0});
int shape_dims_data[] = {1, 0};
TfLiteIntArray* shape_dims = IntArrayFromInts(shape_dims_data);
auto shape_tensor =
tflite::testing::CreateTensor<int32_t, kTfLiteInt32>({0}, shape_dims);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor
&shape_tensor, // shape_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
tflite::testing::TestReshapeWithShapeImpl<float>(
&input_tensor, // input_tensor
&shape_tensor, // shape_tensor
&output_tensor, // output_tensor
{}, // expected_output
{}, // expected_dims
true // expect failure
);
tflite::testing::TestReshapeImpl<float>(&input_tensor, // input_tensor
nullptr, // shape_tensor
&output_tensor, // output_tensor
{3}, // expected_output
{}, // expected_dims
false // expect failure
tflite::testing::TestReshapeWithoutShapeImpl<float>(
&input_tensor, // input_tensor
&output_tensor, // output_tensor
{3}, // expected_output
{}, // expected_dims
false // expect failure
);
}
|
||||
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
|
||||
int outputs_array_data[] = {1, 2};
|
||||
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
|
||||
int temporaries_array_data[] = {0};
|
||||
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
|
||||
|
||||
TfLiteNode node;
|
||||
node.inputs = inputs_array;
|
||||
node.outputs = outputs_array;
|
||||
node.temporaries = temporaries_array;
|
||||
node.user_data = nullptr;
|
||||
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
|
||||
node.custom_initial_data = nullptr;
|
||||
node.custom_initial_data_size = 0;
|
||||
node.delegate = nullptr;
|
||||
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
|
||||
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
|
||||
|
||||
|
@ -46,17 +46,13 @@ void TestRound(const int* input_dims_data, const float* input_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = nullptr;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
TF_LITE_MICRO_EXPECT_NE(nullptr, registration->invoke);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->invoke(&context, &node));
for (int i = 0; i < output_dims_count; ++i) {
@ -59,18 +59,14 @@ void TestSoftmaxFloat(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -124,18 +120,14 @@ void TestSoftmaxQuantized(std::initializer_list<int> input_dims_data,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -188,18 +180,14 @@ void TestSoftmaxQuantizedSigned(
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -76,19 +76,19 @@ void TestSplitTwoOutputsFloat(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -179,19 +179,19 @@ void TestSplitFourOutputsFloat(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({4, 2, 3, 4, 5});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {4, 2, 3, 4, 5};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -275,19 +275,19 @@ void TestSplitTwoOutputsQuantized(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -362,19 +362,19 @@ void TestSplitTwoOutputsQuantized32(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({2, 0, 1});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({2, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});

int inputs_array_data[] = {2, 0, 1};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {2, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -100,12 +100,10 @@ void TestStrideSlide(std::initializer_list<int> input_shape,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
if (expect_prepare_err) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteError,
@ -89,18 +89,14 @@ void ValidateSubGoldens(TfLiteTensor* tensors, int tensors_size,
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 2};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);
int temporaries_array_data[] = {0};
TfLiteIntArray* temporaries_array = IntArrayFromInts(temporaries_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -201,7 +201,6 @@ void ValidateSVDFGoldens(const int batch_size, const int num_units,
node.builtin_data = reinterpret_cast<void*>(&params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node);
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, prepare_status);
@ -275,7 +274,6 @@ void ValidateIntegerSVDFGoldens(const int batch_size, const int num_units,
node.builtin_data = reinterpret_cast<void*>(&params);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TfLiteStatus prepare_status = registration->prepare(&context, &node);
@ -62,12 +62,10 @@ void TestTanhFloat(std::initializer_list<int> input_dims_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -122,12 +120,10 @@ void TestTanhInt8(std::initializer_list<int> input_dims_data,
TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = nullptr;
node.user_data = user_data;
node.builtin_data = nullptr;
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;
if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
}
@ -79,19 +79,18 @@ void TestUnpackThreeOutputsFloat(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -156,19 +155,18 @@ void TestUnpackOneOutputFloat(std::initializer_list<int> input_dims_data,
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({1, 1});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {1, 1};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -245,19 +243,18 @@ void TestUnpackThreeOutputsQuantized(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -338,19 +335,18 @@ void TestUnpackThreeOutputsQuantized32(
if (registration->init) {
user_data = registration->init(&context, nullptr, 0);
}
TfLiteIntArray* inputs_array = IntArrayFromInitializer({1, 0});
TfLiteIntArray* outputs_array = IntArrayFromInitializer({3, 1, 2, 3});
TfLiteIntArray* temporaries_array = IntArrayFromInitializer({0});
int inputs_array_data[] = {1, 0};
TfLiteIntArray* inputs_array = IntArrayFromInts(inputs_array_data);
int outputs_array_data[] = {3, 1, 2, 3};
TfLiteIntArray* outputs_array = IntArrayFromInts(outputs_array_data);

TfLiteNode node;
node.inputs = inputs_array;
node.outputs = outputs_array;
node.temporaries = temporaries_array;
node.user_data = user_data;
node.builtin_data = reinterpret_cast<void*>(&builtin_data);
node.custom_initial_data = nullptr;
node.custom_initial_data_size = 0;
node.delegate = nullptr;

if (registration->prepare) {
TF_LITE_MICRO_EXPECT_EQ(kTfLiteOk, registration->prepare(&context, &node));
@ -45,8 +45,8 @@ constexpr int kKeywordModelNodeAndRegistrationCount = 15;
// Run this test with '--copt=-DTF_LITE_MICRO_OPTIMIZED_RUNTIME' to get
// optimized memory runtime values:
#ifdef TF_LITE_STATIC_MEMORY
constexpr int kKeywordModelTotalSize = 18448;
constexpr int kKeywordModelTailSize = 17776;
constexpr int kKeywordModelTotalSize = 18080;
constexpr int kKeywordModelTailSize = 17408;
#else
constexpr int kKeywordModelTotalSize = 21040;
constexpr int kKeywordModelTailSize = 20368;
@ -65,8 +65,8 @@ constexpr int kTestConvModelNodeAndRegistrationCount = 7;
// NOTE: These values are measured on x86-64:
// TODO(b/158651472): Consider auditing these values on non-64 bit systems.
#ifdef TF_LITE_STATIC_MEMORY
constexpr int kTestConvModelTotalSize = 10960;
constexpr int kTestConvModelTailSize = 3216;
constexpr int kTestConvModelTotalSize = 10784;
constexpr int kTestConvModelTailSize = 3040;
#else
constexpr int kTestConvModelTotalSize = 11680;
constexpr int kTestConvModelTailSize = 3936;
@ -128,26 +128,20 @@ class MicroMutableOpResolver : public MicroOpResolver {
}

TfLiteStatus AddAveragePool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_AVERAGE_POOL_2D,
*tflite::ops::micro::Register_AVERAGE_POOL_2D(),
ParseOpData);
ParsePool);
}

TfLiteStatus AddCeil() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CEIL,
*tflite::ops::micro::Register_CEIL(), ParseOpData);
*tflite::ops::micro::Register_CEIL(), ParseCeil);
}

TfLiteStatus AddConcatenation() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_CONCATENATION,
*tflite::ops::micro::Register_CONCATENATION(),
ParseOpData);
ParseConcatenation);
}

TfLiteStatus AddConv2D() {
@ -156,10 +150,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
}

TfLiteStatus AddCos() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_COS, *tflite::ops::micro::Register_COS(),
ParseOpData);
ParseCos);
}

TfLiteStatus AddDepthwiseConv2D() {
@ -175,17 +167,13 @@ class MicroMutableOpResolver : public MicroOpResolver {
}

TfLiteStatus AddEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_EQUAL,
*tflite::ops::micro::Register_EQUAL(), ParseOpData);
*tflite::ops::micro::Register_EQUAL(), ParseEqual);
}

TfLiteStatus AddFloor() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_FLOOR,
*tflite::ops::micro::Register_FLOOR(), ParseOpData);
*tflite::ops::micro::Register_FLOOR(), ParseFloor);
}

TfLiteStatus AddFullyConnected() {
@ -195,18 +183,14 @@ class MicroMutableOpResolver : public MicroOpResolver {
}

TfLiteStatus AddGreater() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER,
*tflite::ops::micro::Register_GREATER(), ParseOpData);
*tflite::ops::micro::Register_GREATER(), ParseGreater);
}

TfLiteStatus AddGreaterEqual() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_GREATER_EQUAL,
*tflite::ops::micro::Register_GREATER_EQUAL(),
ParseOpData);
ParseGreaterEqual);
}

TfLiteStatus AddL2Normalization() {
@ -274,10 +258,8 @@ class MicroMutableOpResolver : public MicroOpResolver {
}

TfLiteStatus AddMaxPool2D() {
// TODO(b/149408647): Replace ParseOpData with the operator specific parse
// function.
return AddBuiltin(BuiltinOperator_MAX_POOL_2D,
*tflite::ops::micro::Register_MAX_POOL_2D(), ParseOpData);
*tflite::ops::micro::Register_MAX_POOL_2D(), ParsePool);
}

TfLiteStatus AddMean() {
@ -100,7 +100,6 @@ const std::map<string, string>& GetKnownBrokenTests() {
{R"(^\/floor_mod.*activation=True.*dtype=tf\.int32)", "112968789"},
{R"(^\/floor_mod.*activation=True.*dtype=tf\.int64)", "112968789"},

{R"(^\/sub.*dtype=tf\.int64)", "119126484"},
{R"(^\/div.*dtype=tf\.int64)", "119126484"},
{R"(^\/mul.*dtype=tf\.int64)", "119126484"},
{R"(^\/add.*dtype=tf\.int64)", "119126484"},
@ -61,7 +61,7 @@ std::string GetMinimumRuntimeVersionForModel(const Model& model) {
{{OperatorType::kSub, 1}, "1.6.0"},
{{OperatorType::kSub, 2}, "1.14.0"},
{{OperatorType::kSub, 3}, "1.15.0"},
{{OperatorType::kSub, 4}, "1.15.0"},
{{OperatorType::kSub, 4}, kPendingReleaseOpVersion},
{{OperatorType::kSub, 5}, kPendingReleaseOpVersion},
{{OperatorType::kDiv, 1}, "1.6.0"},
{{OperatorType::kBatchToSpaceND, 1}, "1.6.0"},
@ -440,76 +440,6 @@ typedef struct TfLiteTensor {
// `dims_signature` contains [1, -1, -1, 3]).
const TfLiteIntArray* dims_signature;
} TfLiteTensor;
#else
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// NOTE: This flag is opt-in only at compile time.
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;

// Quantization information.
TfLiteQuantizationParams params;

// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;

// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;

// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;

// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;

// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;

// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;
#endif // TF_LITE_STATIC_MEMORY

#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);

// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);

// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);

// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);

// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);

// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY

// A structure representing an instance of a node.
// This structure only exhibits the inputs, outputs and user defined data, not
@ -547,6 +477,112 @@ typedef struct TfLiteNode {
// WARNING: This is an experimental interface that is subject to change.
struct TfLiteDelegate* delegate;
} TfLiteNode;
#else
// NOTE: This flag is opt-in only at compile time.
//
// Specific reduced TfLiteTensor struct for TF Micro runtime. This struct
// contains only the minimum fields required to initialize and prepare a micro
// inference graph. The fields in this struct have been ordered from
// largest-to-smallest for optimal struct sizeof.
//
// This struct does not use:
// - allocation
// - buffer_handle
// - data_is_stale
// - delegate
// - dims_signature
// - name
// - sparsity
typedef struct TfLiteTensor {
// TODO(b/155784997): Consider consolidating these quantization fields:
// Quantization information. Replaces params field above.
TfLiteQuantization quantization;

// Quantization information.
TfLiteQuantizationParams params;

// A union of data pointers. The appropriate type should be used for a typed
// tensor based on `type`.
TfLitePtrUnion data;

// A pointer to a structure representing the dimensionality interpretation
// that the buffer should have. NOTE: the product of elements of `dims`
// and the element datatype size should be equal to `bytes` below.
TfLiteIntArray* dims;

// The number of bytes required to store the data of this Tensor. I.e.
// (bytes of each element) * dims[0] * ... * dims[n-1]. For example, if
// type is kTfLiteFloat32 and dims = {3, 2} then
// bytes = sizeof(float) * 3 * 2 = 4 * 3 * 2 = 24.
size_t bytes;

// The data type specification for data stored in `data`. This affects
// what member of `data` union should be used.
TfLiteType type;

// How memory is mapped
// kTfLiteMmapRo: Memory mapped read only.
// i.e. weights
// kTfLiteArenaRw: Arena allocated read write memory
// (i.e. temporaries, outputs).
TfLiteAllocationType allocation_type;

// True if the tensor is a variable.
bool is_variable;
} TfLiteTensor;

// Specific reduced TfLiteNode struct for TF Micro runtime. This struct contains
// only the minimum fields required to represent a node.
//
// This struct does not use:
// - delegate
// - intermediates
// - temporaries
typedef struct TfLiteNode {
// Inputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* inputs;

// Outputs to this node expressed as indices into the simulator's tensors.
TfLiteIntArray* outputs;

// Opaque data provided by the node implementer through `Registration.init`.
void* user_data;

// Opaque data provided to the node if the node is a builtin. This is usually
// a structure defined in builtin_op_data.h
void* builtin_data;

// Custom initial data. This is the opaque data provided in the flatbuffer.
// WARNING: This is an experimental interface that is subject to change.
const void* custom_initial_data;
int custom_initial_data_size;
} TfLiteNode;
#endif // TF_LITE_STATIC_MEMORY

#ifndef TF_LITE_STATIC_MEMORY
// Free data memory of tensor `t`.
void TfLiteTensorDataFree(TfLiteTensor* t);

// Free quantization data.
void TfLiteQuantizationFree(TfLiteQuantization* quantization);

// Free sparsity parameters.
void TfLiteSparsityFree(TfLiteSparsity* sparsity);

// Free memory of tensor `t`.
void TfLiteTensorFree(TfLiteTensor* t);

// Set all of a tensor's fields (and free any previously allocated data).
void TfLiteTensorReset(TfLiteType type, const char* name, TfLiteIntArray* dims,
TfLiteQuantizationParams quantization, char* buffer,
size_t size, TfLiteAllocationType allocation_type,
const void* allocation, bool is_variable,
TfLiteTensor* tensor);

// Resize the allocated data of a (dynamic) tensor. Tensors with allocation
// types other than kTfLiteDynamic will be ignored.
void TfLiteTensorRealloc(size_t num_bytes, TfLiteTensor* tensor);
#endif // TF_LITE_STATIC_MEMORY

// WARNING: This is an experimental interface that is subject to change.
//
@ -466,12 +466,14 @@ int GetBuiltinOperatorVersion(const OpSignature& op_sig) {
op_sig.output_types.at(0) == TensorType_INT16) {
if (op_sig.options.addsub.pot_scale_int16) {
return 5;
} else {
return 4;
}
}
}
if (op_sig.options.addsub.need_broadcast &&
op_sig.options.addsub.num_dims > 4) {
if (!op_sig.input_types.empty() &&
op_sig.input_types.at(0) == TensorType_INT64) {
return 4;
}
if (op_sig.options.broadcast.need_broadcast &&
op_sig.options.broadcast.num_dims > 4) {
return 3;
}
if (op_sig.input_types.at(0) == TensorType_INT8) {
@ -296,6 +296,14 @@ TEST(OpVersionTest, VersioningSubTest) {
SimpleVersioningTest(BuiltinOperator_SUB);
}

TEST(OpVersionTest, VersioningSub4Test) {
OpSignature fake_op_sig = {
.op = BuiltinOperator_SUB,
.input_types = std::vector<TensorType>{TensorType_INT64},
};
EXPECT_EQ(GetBuiltinOperatorVersion(fake_op_sig), 4);
}

void SimpleMulVersioningTest(TensorType data_type, float multiplier,
int version) {
OpSignature fake_op_sig = {
@ -79,6 +79,7 @@ std::string FindMinimumRuntimeVersionForOp(tflite::BuiltinOperator op_code,
{{BuiltinOperator_SUB, 1}, "1.6.0"},
{{BuiltinOperator_SUB, 2}, "1.14.0"},
{{BuiltinOperator_SUB, 3}, kPendingReleaseVersion},
{{BuiltinOperator_SUB, 4}, kPendingReleaseVersion},
{{BuiltinOperator_DENSIFY, 1}, "2.2.0"},
{{BuiltinOperator_DIV, 1}, "1.6.0"},
{{BuiltinOperator_DIV, 2}, kPendingReleaseVersion},
@ -33,7 +33,7 @@ from tensorflow.python.util.tf_export import tf_export
# This value changes every day with an automatic CL. It can be modified in code
# via `forward_compatibility_horizon()` or with the environment variable
# TF_FORWARD_COMPATIBILITY_DELTA_DAYS, which is added to the compatibility date.
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 6, 23)
_FORWARD_COMPATIBILITY_HORIZON = datetime.date(2020, 6, 24)
_FORWARD_COMPATIBILITY_DELTA_DAYS_VAR_NAME = "TF_FORWARD_COMPATIBILITY_DELTA_DAYS"
_FORWARD_COMPATIBILITY_DATE_NUMBER = None
@ -1116,6 +1116,8 @@ py_library(
"//tensorflow/python:framework_ops",
"//tensorflow/python:gradients_impl",
"//tensorflow/python:init_ops",
"//tensorflow/python:init_ops_v2",
"//tensorflow/python:math_ops",
"//tensorflow/python:training",
"//tensorflow/python:util",
"//tensorflow/python:variable_scope",
@ -1124,7 +1126,6 @@ py_library(
"//tensorflow/python/eager:backprop",
"//tensorflow/python/eager:context",
"//tensorflow/python/eager:test",
"//tensorflow/python/keras/layers",
"//third_party/py/numpy",
],
)
@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import tempfile

@ -38,11 +39,12 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import test_util
from tensorflow.python.keras.layers import core
from tensorflow.python.lib.io import tf_record
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_math_ops
from tensorflow.python.ops import gradients_impl
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import init_ops_v2
from tensorflow.python.ops import summary_ops_v2 as summary_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import variables
@ -114,12 +116,19 @@ def _events_from_logdir(test_case, logdir):
class DistributionTestBase(test.TestCase):
"""Some tests that should work with any DistributionStrategy."""

def _create_variable_like_keras_dense_layer(self, name, shape, dtype):
initializer = functools.partial(
init_ops_v2.GlorotUniform(), shape, dtype=dtype)
return variables.Variable(
initial_value=initializer, name=name, trainable=True)

def _test_minimize_loss_eager(self, d):
with d.scope():
l = core.Dense(1, use_bias=False)

kernel = self._create_variable_like_keras_dense_layer(
name="kernel", shape=(1, 1), dtype=dtypes.float32)
def loss(x):
y = array_ops.reshape(l(x), []) - array_ops.identity(1.)
y = array_ops.reshape(
gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
return y * y
# TODO(isaprykin): Extract implicit_grad+get_filtered_grad_fn into a
# common `implicit_grad` function and put it in DistributionStrategy.
@ -173,10 +182,12 @@ class DistributionTestBase(test.TestCase):
ops.Graph().as_default(), \
self.cached_session(config=config) as sess, \
d.scope():
l = core.Dense(1, use_bias=False)
kernel = self._create_variable_like_keras_dense_layer(
name="kernel", shape=(1, 1), dtype=dtypes.float32)

def loss(x):
y = array_ops.reshape(l(x), []) - array_ops.identity(1.)
y = array_ops.reshape(
gen_math_ops.mat_mul(x, kernel), []) - array_ops.identity(1.)
return y * y

grad_fn = backprop.implicit_grad(loss)
@ -890,13 +890,13 @@ class Context(object):

@property
def executor(self):
ensure_initialized()
self.ensure_initialized()
return executor.Executor(
pywrap_tfe.TFE_ContextGetExecutorForThread(self._context_handle))

@executor.setter
def executor(self, e):
ensure_initialized()
self.ensure_initialized()
pywrap_tfe.TFE_ContextSetExecutorForThread(self._context_handle, e.handle())

@property
@ -1192,7 +1192,7 @@ class _TapeGradientFunctions(object):
def _wrap_backward_function(self, forward_graph, backward, outputs):
"""Create a backward function given `outputs` from the forward function."""
capture_mapping = dict(
zip([ops.tensor_id(t) for t in forward_graph.outputs], outputs))
zip((ops.tensor_id(t) for t in forward_graph.outputs), outputs))
remapped_captures = [
capture_mapping.get(ops.tensor_id(capture), capture)
for capture in backward.captured_inputs
@ -1489,9 +1489,8 @@ class ConcreteFunction(object):
self._captured_closures = self._func_graph.deferred_external_captures
structured_outputs = self._func_graph.structured_outputs
self._ndarrays_list = (
isinstance(structured_outputs, (list, tuple)) and
structured_outputs and
all([isinstance(o, np_arrays.ndarray) for o in structured_outputs]))
isinstance(structured_outputs, (list, tuple)) and structured_outputs and
all(isinstance(o, np_arrays.ndarray) for o in structured_outputs))
self._ndarray_singleton = isinstance(structured_outputs, np_arrays.ndarray)

# function_spec defines the structured signature.
@ -2199,6 +2198,14 @@ class ConcreteFunction(object):
assert self._function_spec is not None
arg_specs, kwarg_specs = self.structured_input_signature
arg_names = list(self._function_spec.arg_names)

# If an explicit input_signature is provided to @tf.function, then any
# arguments with defaults that are not covered by that explicit signature
# are simply dropped from the signature.
# TODO(b/159639913) Look into whether dropping arguments with default values
# from the signature is the right thing to do.
arg_names = arg_names[:len(arg_specs)]

if default_values:
for i in range(len(arg_names)):
if not _contains_type_spec(arg_specs[i]):
@ -2248,6 +2255,14 @@ class ConcreteFunction(object):
lines = [self._structured_signature_summary(default_values=True)]
arg_specs, kwarg_specs = self.structured_input_signature
names = list(self._function_spec.arg_names)

# If an explicit input_signature is provided to @tf.function, then any
# arguments with defaults that are not covered by that explicit signature
# are simply dropped from the signature.
# TODO(b/159639913) Look into whether dropping arguments with default values
# from the signature is the right thing to do.
names = names[:len(arg_specs)]

names.extend(sorted(kwarg_specs))
specs = list(arg_specs) + list(kwarg_specs.values())
# note: we can skip bound args, since we already displayed their bound
@ -2855,7 +2870,6 @@ class Function(object):
graph_function, _, _ = self._maybe_define_function(args, kwargs)
return graph_function

# XX TODO: make sure we fix up this path as well!?
def _get_concrete_function_internal(self, *args, **kwargs):
"""Bypasses error checking when getting a graph function."""
graph_function = self._get_concrete_function_internal_garbage_collected(