Merge branch 'master' into addsub_16x8
This commit is contained in:
commit
1cf086cf0b
@ -3,6 +3,10 @@
|
||||
# Targets in this directory are pure C++ "Classes" underlying the C API types
|
||||
# under tf/c/experimental/saved_model/public/. They are subject to change and
|
||||
# have visibility limited to Tensorflow's implementation only.
|
||||
load(
|
||||
"//tensorflow:tensorflow.bzl",
|
||||
"tf_cc_test",
|
||||
)
|
||||
|
||||
package(
|
||||
default_visibility = [
|
||||
@ -47,6 +51,22 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "saved_model_utils",
|
||||
srcs = [
|
||||
"saved_model_utils.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"saved_model_utils.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c:tf_tensor_internal",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tf_saved_model_impl",
|
||||
srcs = [
|
||||
@ -84,3 +104,26 @@ filegroup(
|
||||
],
|
||||
visibility = ["//tensorflow/core:__pkg__"],
|
||||
)
|
||||
|
||||
tf_cc_test(
|
||||
name = "saved_model_utils_test",
|
||||
srcs = [
|
||||
"saved_model_utils_test.cc",
|
||||
],
|
||||
deps = [
|
||||
":saved_model_utils",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:abstract_tensor_handle",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/c/experimental/saved_model/core/revived_types:constant",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core:test",
|
||||
"//tensorflow/core:test_main",
|
||||
"//tensorflow/core/common_runtime:core_cpu_lib",
|
||||
"//tensorflow/core/common_runtime/eager:context",
|
||||
"//tensorflow/core/common_runtime/eager:core",
|
||||
],
|
||||
)
|
||||
|
@ -0,0 +1,39 @@
|
||||
# This package contains classes corresponding to Revived SavedObjectGraph types
|
||||
# used by SavedModel. See https://cs.opensource.google/tensorflow/tensorflow/+/c575e2ba93c442121d98d3f125d83fed1339924d:tensorflow/core/protobuf/saved_object_graph.proto;l=56-62
|
||||
package(
|
||||
default_visibility = [
|
||||
# Restricting visibility for now
|
||||
"//tensorflow/c/experimental/saved_model/core:__pkg__",
|
||||
],
|
||||
licenses = ["notice"], # Apache 2.0
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "constant",
|
||||
srcs = [
|
||||
"constant.cc",
|
||||
],
|
||||
hdrs = [
|
||||
"constant.h",
|
||||
],
|
||||
deps = [
|
||||
":tensorhandle_convertible",
|
||||
"//tensorflow/c:tensor_interface",
|
||||
"//tensorflow/c/eager:immediate_execution_context",
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"//tensorflow/core/common_runtime/eager:tensor_handle",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tensorhandle_convertible",
|
||||
hdrs = [
|
||||
"tensorhandle_convertible.h",
|
||||
],
|
||||
deps = [
|
||||
"//tensorflow/c/eager:immediate_execution_tensor_handle",
|
||||
],
|
||||
)
|
@ -0,0 +1,46 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/core/common_runtime/eager/tensor_handle.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/platform/errors.h"
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
Constant::Constant(ImmediateTensorHandlePtr handle)
|
||||
: TensorHandleConvertible(std::move(handle)) {}
|
||||
|
||||
Status Constant::Create(ImmediateExecutionContext* ctx,
|
||||
AbstractTensorInterface* tensor,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
ImmediateExecutionTensorHandle* handle = ctx->CreateLocalHandle(tensor);
|
||||
if (handle == nullptr) {
|
||||
return errors::Internal("Failed to convert tensor to tensorhandle");
|
||||
}
|
||||
output->reset(new Constant(ImmediateTensorHandlePtr(handle)));
|
||||
return Status();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
@ -0,0 +1,55 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/tensorhandle_convertible.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// This class corresponds to python's tf.constant, which is effectively a
|
||||
// TensorHandle explicitly initialized to some value.
|
||||
// For now this doesn't do much beyond wrap Context's CreateLocalHandle method,
|
||||
// and offer a subclass of TensorHandleConvertible. Note that similar to
|
||||
// the python's eager mode logic, we bypass calling the "Const" op:
|
||||
// https://github.com/tensorflow/tensorflow/blob/1c064ab76064c58e54261b805027474885a1534d/tensorflow/python/framework/constant_op.py#L301
|
||||
class Constant : public TensorHandleConvertible {
|
||||
public:
|
||||
static Status Create(ImmediateExecutionContext* ctx,
|
||||
AbstractTensorInterface* tensor,
|
||||
std::unique_ptr<Constant>* output);
|
||||
|
||||
// RevivedConstant is movable, but not copyable.
|
||||
Constant(Constant&& other) = default;
|
||||
Constant& operator=(Constant&& other) = default;
|
||||
|
||||
~Constant() override = default;
|
||||
|
||||
private:
|
||||
explicit Constant(ImmediateTensorHandlePtr handle);
|
||||
Constant(const Constant&) = delete;
|
||||
Constant& operator=(const Constant&) = delete;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_REVIVED_CONSTANT_H_
|
@ -0,0 +1,49 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
// A common interface for objects that can be converted to a TensorHandle.
|
||||
// Examples of objects that implement this include Variables, Constants, Assets,
|
||||
// etc. This is used to convert captured objects into a ConcreteFunction's
|
||||
// captured TensorHandles:
|
||||
// https://github.com/tensorflow/tensorflow/blob/676a68963ea4b64fe479b9cede06aa8f5b290ab8/tensorflow/python/saved_model/load.py#L229-L240
|
||||
class TensorHandleConvertible {
|
||||
public:
|
||||
explicit TensorHandleConvertible(ImmediateTensorHandlePtr handle)
|
||||
: handle_(std::move(handle)) {}
|
||||
|
||||
ImmediateExecutionTensorHandle* handle() { return handle_.get(); }
|
||||
|
||||
// TensorHandleConvertible is movable, but not copyable.
|
||||
TensorHandleConvertible(TensorHandleConvertible&& other) = default;
|
||||
TensorHandleConvertible& operator=(TensorHandleConvertible&& other) = default;
|
||||
|
||||
virtual ~TensorHandleConvertible() = default;
|
||||
|
||||
protected:
|
||||
TensorHandleConvertible(const TensorHandleConvertible&) = delete;
|
||||
TensorHandleConvertible& operator=(const TensorHandleConvertible&) = delete;
|
||||
ImmediateTensorHandlePtr handle_;
|
||||
};
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_TENSORHANDLE_CONVERTIBLE_H_
|
@ -0,0 +1,38 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/tf_tensor_internal.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output) {
|
||||
tensorflow::Tensor tensor;
|
||||
bool parse_result = tensor.FromProto(proto);
|
||||
if (!parse_result) {
|
||||
return errors::Internal("Failed to parse tensor from tensorproto");
|
||||
}
|
||||
|
||||
TensorInterface tensor_interface(std::move(tensor));
|
||||
return Constant::Create(ctx, &tensor_interface, output);
|
||||
}
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
@ -0,0 +1,39 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
||||
#define TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
||||
|
||||
// Some internal utility functions for the SavedModelAPI, factored out into a
|
||||
// separately unit-testable header.
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_context.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace internal {
|
||||
|
||||
// Load a TensorProto into a tensorflow::Constant. This is similar to the
|
||||
// constant loading logic in python:
|
||||
// https://github.com/tensorflow/tensorflow/blob/516608035f85cec8b126712b0ff8407220206b22/tensorflow/python/saved_model/load.py#L437
|
||||
Status TensorProtoToConstant(ImmediateExecutionContext* ctx,
|
||||
const TensorProto& proto,
|
||||
std::unique_ptr<Constant>* output);
|
||||
|
||||
} // namespace internal
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_C_EXPERIMENTAL_SAVED_MODEL_CORE_SAVED_MODEL_UTILS_H_
|
@ -0,0 +1,199 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/c/experimental/saved_model/core/saved_model_utils.h"
|
||||
|
||||
#include <string.h>
|
||||
|
||||
#include <memory>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/c/eager/immediate_execution_tensor_handle.h"
|
||||
#include "tensorflow/c/experimental/saved_model/core/revived_types/constant.h"
|
||||
#include "tensorflow/c/tensor_interface.h"
|
||||
#include "tensorflow/core/common_runtime/device_mgr.h"
|
||||
#include "tensorflow/core/common_runtime/eager/context.h"
|
||||
#include "tensorflow/core/framework/numeric_types.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/framework/tensor.pb.h"
|
||||
#include "tensorflow/core/framework/tensor_shape.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/framework/types.pb.h"
|
||||
#include "tensorflow/core/lib/bfloat16/bfloat16.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/platform/types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace {
|
||||
|
||||
// Converts a tensorflow::DatatypeSet to std::vector<DataType>.
|
||||
// This is needed for GTest's ::testing::ValuesIn, since
|
||||
// DataTypeSet doesn't fullfill all the constraints of an STL-like iterable.
|
||||
std::vector<DataType> DataTypeSetToVector(DataTypeSet set) {
|
||||
std::vector<DataType> result;
|
||||
result.reserve(set.size());
|
||||
for (DataType dt : set) {
|
||||
result.push_back(dt);
|
||||
}
|
||||
return result;
|
||||
}
|
||||
|
||||
// Returns a vector of shapes intended to be "interesting" test cases.
|
||||
std::vector<std::vector<int64>> InterestingShapes() {
|
||||
std::vector<std::vector<int64>> interesting_shapes;
|
||||
interesting_shapes.push_back({}); // Scalar
|
||||
interesting_shapes.push_back({10}); // 1D Vector
|
||||
interesting_shapes.push_back({3, 3}); // 2D Matrix
|
||||
interesting_shapes.push_back({1, 4, 6, 10}); // Higher Dimension Tensor
|
||||
return interesting_shapes;
|
||||
}
|
||||
|
||||
// Fills a numeric tensor with `value`.
|
||||
void FillNumericTensor(Tensor* tensor, int8 value) {
|
||||
switch (tensor->dtype()) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
const auto& flattened = tensor->flat<type>(); \
|
||||
for (int i = 0; i < tensor->NumElements(); ++i) { \
|
||||
flattened(i) = value; \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: "
|
||||
<< DataTypeString(tensor->dtype());
|
||||
break;
|
||||
}
|
||||
}
|
||||
|
||||
// Checks the underlying data is equal for the buffers for two numeric tensors.
|
||||
// Note: The caller must ensure to check that the dtypes and sizes of the
|
||||
// underlying buffers are the same before calling this.
|
||||
void CheckBufferDataIsEqual(DataType dtype, int64 num_elements, void* a,
|
||||
void* b) {
|
||||
switch (dtype) {
|
||||
#define CASE(type) \
|
||||
case DataTypeToEnum<type>::value: { \
|
||||
type* typed_a = static_cast<type*>(a); \
|
||||
type* typed_b = static_cast<type*>(b); \
|
||||
for (int64 i = 0; i < num_elements; ++i) { \
|
||||
if (DataTypeIsFloating(dtype)) { \
|
||||
EXPECT_FLOAT_EQ(typed_a[i], typed_b[i]); \
|
||||
} else { \
|
||||
EXPECT_EQ(typed_a[i], typed_b[i]); \
|
||||
} \
|
||||
} \
|
||||
break; \
|
||||
}
|
||||
TF_CALL_INTEGRAL_TYPES(CASE);
|
||||
TF_CALL_double(CASE);
|
||||
TF_CALL_float(CASE);
|
||||
#undef CASE
|
||||
default:
|
||||
CHECK(false) << "Unsupported data type: " << DataTypeString(dtype);
|
||||
}
|
||||
}
|
||||
|
||||
class ConstantTest : public ::testing::TestWithParam<
|
||||
std::tuple<DataType, std::vector<int64>, bool>> {
|
||||
public:
|
||||
ConstantTest()
|
||||
: device_mgr_(std::make_unique<StaticDeviceMgr>(DeviceFactory::NewDevice(
|
||||
"CPU", {}, "/job:localhost/replica:0/task:0"))),
|
||||
ctx_(new EagerContext(
|
||||
SessionOptions(),
|
||||
tensorflow::ContextDevicePlacementPolicy::DEVICE_PLACEMENT_SILENT,
|
||||
tensorflow::ContextMirroringPolicy::MIRRORING_NONE,
|
||||
/* async= */ false,
|
||||
/* lazy_copy_function_remote_inputs= */ false, device_mgr_.get(),
|
||||
/* device_mgr_owned= */ false, /* rendezvous= */ nullptr,
|
||||
/* custom_kernel_creator= */ nullptr,
|
||||
/* cluster_flr= */ nullptr)) {}
|
||||
|
||||
EagerContext* context() { return ctx_.get(); }
|
||||
|
||||
private:
|
||||
std::unique_ptr<StaticDeviceMgr> device_mgr_;
|
||||
EagerContextPtr ctx_;
|
||||
};
|
||||
|
||||
// Basic sanity check that roundtripping a Tensor->Tensorproto->Constant
|
||||
// preserves values.
|
||||
TEST_P(ConstantTest, CreateConstantSuccessful) {
|
||||
// Get test parameters
|
||||
auto& test_params = GetParam();
|
||||
DataType dtype = std::get<0>(test_params);
|
||||
TensorShape shape(std::get<1>(test_params));
|
||||
bool tensorproto_use_tensor_content = std::get<2>(test_params);
|
||||
|
||||
// Construct a Tensor with the given dtype + shape
|
||||
Tensor expected(dtype, shape);
|
||||
FillNumericTensor(&expected, 42);
|
||||
|
||||
// Serialize it to a Tensorproto
|
||||
TensorProto proto;
|
||||
if (tensorproto_use_tensor_content) {
|
||||
expected.AsProtoTensorContent(&proto);
|
||||
} else {
|
||||
expected.AsProtoField(&proto);
|
||||
}
|
||||
|
||||
// Revival should succeed w/o errors
|
||||
std::unique_ptr<Constant> revived;
|
||||
TF_EXPECT_OK(internal::TensorProtoToConstant(context(), proto, &revived));
|
||||
|
||||
// The revived tensorhandle should have the exact same dtype, shape, +
|
||||
// approx equivalent data to the original.
|
||||
ImmediateExecutionTensorHandle* handle = revived->handle();
|
||||
Status status;
|
||||
AbstractTensorPtr revived_tensor(handle->Resolve(&status));
|
||||
TF_EXPECT_OK(status) << "Failed to convert tensorhandle to tensor";
|
||||
EXPECT_EQ(revived_tensor->Type(), expected.dtype());
|
||||
EXPECT_EQ(revived_tensor->NumElements(), expected.NumElements());
|
||||
EXPECT_EQ(revived_tensor->NumDims(), expected.dims());
|
||||
for (int i = 0; i < expected.dims(); ++i) {
|
||||
EXPECT_EQ(revived_tensor->Dim(i), expected.dim_size(i));
|
||||
}
|
||||
|
||||
CheckBufferDataIsEqual(expected.dtype(), expected.NumElements(),
|
||||
revived_tensor->Data(), expected.data());
|
||||
}
|
||||
|
||||
// Test against combinations of tensors that are
|
||||
// 1. Varying dtypes
|
||||
// 2. Varying shapes
|
||||
// 3. TensorProto serialized using tensor_content vs repeated type
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantIntegerDtypesTest, ConstantTest,
|
||||
::testing::Combine(
|
||||
::testing::ValuesIn(DataTypeSetToVector(kDataTypeIsInteger)),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
INSTANTIATE_TEST_SUITE_P(
|
||||
ConstantFloatingDtypesTest, ConstantTest,
|
||||
::testing::Combine(::testing::Values(DT_FLOAT, DT_DOUBLE),
|
||||
::testing::ValuesIn(InterestingShapes()),
|
||||
::testing::Values(false, true)));
|
||||
|
||||
} // namespace
|
||||
} // namespace tensorflow
|
@ -69,6 +69,7 @@ cc_library(
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
"@llvm-project//llvm:ARMCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:PowerPCCodeGen", # fixdeps: keep
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//llvm:Target",
|
||||
"@llvm-project//llvm:X86CodeGen", # fixdeps: keep
|
||||
"//tensorflow/core:regexp_internal",
|
||||
|
@ -22,6 +22,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/base/call_once.h"
|
||||
#include "llvm-c/Target.h"
|
||||
#include "llvm/Support/ManagedStatic.h"
|
||||
#include "tensorflow/compiler/aot/codegen.h"
|
||||
#include "tensorflow/compiler/aot/flags.h"
|
||||
#include "tensorflow/compiler/aot/quantize.h"
|
||||
|
@ -476,10 +476,36 @@ Status XlaComputationLaunchContext::PopulateOutputs(
|
||||
stream->ThenRecordEvent(definition_event.get());
|
||||
}
|
||||
|
||||
std::vector<TensorShape> output_tensor_shapes;
|
||||
output_tensor_shapes.reserve(ctx->num_outputs());
|
||||
if (output.on_host_shape().is_dynamic()) {
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
|
||||
xla::Shape output_host_shape = output.on_host_shape();
|
||||
xla::Shape output_device_shape = output.on_device_shape();
|
||||
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||
stream, &output, &output_host_shape, &output_device_shape));
|
||||
|
||||
output.set_shapes(output_host_shape, output_device_shape);
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const xla::Shape& subshape =
|
||||
xla::ShapeUtil::GetSubshape(output_host_shape, {i});
|
||||
TensorShape shape;
|
||||
TF_RETURN_IF_ERROR(XLAShapeToTensorShape(subshape, &shape));
|
||||
output_tensor_shapes.push_back(shape);
|
||||
}
|
||||
} else {
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
output_tensor_shapes.push_back(compilation_result->outputs[i].shape);
|
||||
}
|
||||
}
|
||||
|
||||
// Copy XLA results to the OpOutputList.
|
||||
int output_num = 0;
|
||||
for (int i = 0; i < ctx->num_outputs(); ++i) {
|
||||
const TensorShape& shape = compilation_result->outputs[i].shape;
|
||||
const TensorShape& shape = output_tensor_shapes[i];
|
||||
const DataType& type = compilation_result->outputs[i].type;
|
||||
VLOG(2) << "Retval " << i << " shape " << shape.DebugString() << " type "
|
||||
<< DataTypeString(type);
|
||||
|
@ -24,10 +24,10 @@ the codegen input.
|
||||
|
||||
## Tasks
|
||||
|
||||
| Host | Device
|
||||
------------- | ------------------------ | ------------------------
|
||||
Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1)
|
||||
Output format | xla::Thunk (Task 2) | LLVM IR (Task 3)
|
||||
| | Host | Device
|
||||
| ------------- | ------------------------ | ------------------------
|
||||
| Input format | HloInstruction* (Task 1) | HloInstruction* (Task 1)
|
||||
| Output format | xla::Thunk (Task 2) | LLVM IR (Task 3)
|
||||
|
||||
* **Task 1** changes both host and device input format from HloInstruction* to
|
||||
LHLO.
|
||||
|
@ -26,6 +26,7 @@ filegroup(
|
||||
"//tensorflow/compiler/mlir/lite:flatbuffer_to_string",
|
||||
"//tensorflow/compiler/mlir/lite:tf_tfl_translate",
|
||||
"@llvm-project//llvm:FileCheck",
|
||||
"@llvm-project//llvm:not",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -0,0 +1,257 @@
|
||||
# RUN: not tf_tfl_translate -tf-upgrade-legacy=false -tf-input-arrays=Placeholder,Placeholder_1 -tf-input-shapes=1,2:1 -tf-output-arrays=cond/Merge -tf-enable-shape-inference-on-import=false -mlir-print-debuginfo -output-mlir %s -o - 2>&1 | FileCheck %s
|
||||
|
||||
# CHECK: error: The graph has Control Flow V1 ops. TFLite converter doesn't support Control Flow V1 ops. Consider using Control Flow V2 ops instead.
|
||||
|
||||
node {
|
||||
name: "Const"
|
||||
op: "Const"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "value"
|
||||
value {
|
||||
tensor {
|
||||
dtype: DT_FLOAT
|
||||
tensor_shape {
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
tensor_content: "\315\314\314=\315\314L>\232\231\231>\315\314\314>"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
dim {
|
||||
size: -1
|
||||
}
|
||||
dim {
|
||||
size: 2
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "Placeholder_1"
|
||||
op: "Placeholder"
|
||||
attr {
|
||||
key: "dtype"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "shape"
|
||||
value {
|
||||
shape {
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder_1"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/switch_t"
|
||||
op: "Identity"
|
||||
input: "cond/Switch:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/switch_f"
|
||||
op: "Identity"
|
||||
input: "cond/Switch"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/pred_id"
|
||||
op: "Identity"
|
||||
input: "Placeholder_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_BOOL
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul"
|
||||
op: "MatMul"
|
||||
input: "cond/MatMul/Switch:1"
|
||||
input: "cond/MatMul/Switch_1:1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "transpose_a"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "transpose_b"
|
||||
value {
|
||||
b: false
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Placeholder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/MatMul/Switch_1"
|
||||
op: "Switch"
|
||||
input: "Const"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add"
|
||||
op: "Add"
|
||||
input: "cond/Add/Switch"
|
||||
input: "cond/Add/Switch_1"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add/Switch"
|
||||
op: "Switch"
|
||||
input: "Placeholder"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Placeholder"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Add/Switch_1"
|
||||
op: "Switch"
|
||||
input: "Const"
|
||||
input: "cond/pred_id"
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "_class"
|
||||
value {
|
||||
list {
|
||||
s: "loc:@Const"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "cond/Merge"
|
||||
op: "Merge"
|
||||
input: "cond/Add"
|
||||
input: "cond/MatMul"
|
||||
attr {
|
||||
key: "N"
|
||||
value {
|
||||
i: 2
|
||||
}
|
||||
}
|
||||
attr {
|
||||
key: "T"
|
||||
value {
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
node {
|
||||
name: "init"
|
||||
op: "NoOp"
|
||||
}
|
||||
versions {
|
||||
producer: 134
|
||||
}
|
@ -172,7 +172,7 @@ int main(int argc, char **argv) {
|
||||
input_file_name, input_mlir, use_splatted_constant, custom_opdefs,
|
||||
debug_info_file, input_arrays, input_dtypes, input_shapes,
|
||||
output_arrays,
|
||||
/*prune_unused_nodes=*/true, &source_mgr, &context);
|
||||
/*prune_unused_nodes=*/true, upgrade_legacy, &source_mgr, &context);
|
||||
}
|
||||
|
||||
// If errors occur, the library call in the above already logged the error
|
||||
|
@ -21,6 +21,7 @@ limitations under the License.
|
||||
#include "llvm/Support/raw_ostream.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "mlir/IR/Visitors.h" // from @llvm-project
|
||||
#include "mlir/Parser.h" // from @llvm-project
|
||||
#include "mlir/Pass/Pass.h" // from @llvm-project
|
||||
#include "mlir/Support/FileUtilities.h" // from @llvm-project
|
||||
@ -28,6 +29,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/mlir/lite/flatbuffer_export.h"
|
||||
#include "tensorflow/compiler/mlir/lite/quantization/quantization_config.h"
|
||||
#include "tensorflow/compiler/mlir/lite/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_executor.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/decode_constant.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/transforms/passes.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/translate/tf_mlir_translate.h"
|
||||
@ -39,19 +41,47 @@ limitations under the License.
|
||||
#include "tensorflow/stream_executor/lib/statusor.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
namespace {
|
||||
using mlir::MLIRContext;
|
||||
using mlir::ModuleOp;
|
||||
using mlir::Operation;
|
||||
using mlir::OwningModuleRef;
|
||||
using stream_executor::port::StatusOr;
|
||||
|
||||
bool IsControlFlowV1Op(Operation* op) {
|
||||
return mlir::isa<mlir::tf_executor::SwitchOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::MergeOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::EnterOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::ExitOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSinkOp>(op) ||
|
||||
mlir::isa<mlir::tf_executor::NextIterationSourceOp>(op);
|
||||
}
|
||||
|
||||
mlir::LogicalResult IsValidGraph(mlir::ModuleOp module) {
|
||||
auto result = module.walk([&](Operation* op) {
|
||||
return IsControlFlowV1Op(op) ? mlir::WalkResult::interrupt()
|
||||
: mlir::WalkResult::advance();
|
||||
});
|
||||
if (result.wasInterrupted()) {
|
||||
module.emitError(
|
||||
"The graph has Control Flow V1 ops. TFLite converter doesn't support "
|
||||
"Control Flow V1 ops. Consider using Control Flow V2 ops instead. See "
|
||||
"https://www.tensorflow.org/api_docs/python/tf/compat/v1/"
|
||||
"enable_control_flow_v2.");
|
||||
return mlir::failure();
|
||||
}
|
||||
return mlir::success();
|
||||
}
|
||||
} // namespace
|
||||
|
||||
StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
const std::string& input_filename, bool input_mlir,
|
||||
bool use_splatted_constant, const std::vector<std::string>& extra_tf_opdefs,
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, MLIRContext* context) {
|
||||
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
|
||||
MLIRContext* context) {
|
||||
// Set up the input file.
|
||||
std::string error_message;
|
||||
auto file = mlir::openInputFile(input_filename, &error_message);
|
||||
@ -86,14 +116,14 @@ StatusOr<OwningModuleRef> LoadFromGraphdefOrMlirSource(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
/*graph_as_function=*/false, enable_upgrade_legacy,
|
||||
/*enable_shape_inference=*/false, context);
|
||||
}
|
||||
return tensorflow::GraphdefToMlirTranslateFunction(
|
||||
file->getBuffer(), debug_info_file, input_arrays, input_dtypes,
|
||||
input_shapes, output_arrays, /*control_output_arrays=*/"",
|
||||
prune_unused_nodes, /*convert_legacy_fed_inputs=*/true,
|
||||
/*graph_as_function=*/false, /*upgrade_legacy=*/true,
|
||||
/*graph_as_function=*/false, enable_upgrade_legacy,
|
||||
/*enable_shape_inference=*/false, context);
|
||||
}
|
||||
|
||||
@ -104,7 +134,8 @@ Status ConvertTFExecutorToTFLOrFlatbuffer(
|
||||
mlir::PassManager* pass_manager) {
|
||||
mlir::StatusScopedDiagnosticHandler statusHandler(module.getContext(),
|
||||
/*propagate=*/true);
|
||||
if (failed(pass_manager->run(module))) {
|
||||
|
||||
if (failed(IsValidGraph(module)) || failed(pass_manager->run(module))) {
|
||||
return statusHandler.ConsumeStatus();
|
||||
}
|
||||
|
||||
|
@ -41,7 +41,8 @@ LoadFromGraphdefOrMlirSource(
|
||||
absl::string_view debug_info_file, absl::string_view input_arrays,
|
||||
absl::string_view input_dtypes, absl::string_view input_shapes,
|
||||
absl::string_view output_arrays, bool prune_unused_nodes,
|
||||
llvm::SourceMgr* source_mgr, mlir::MLIRContext* context);
|
||||
bool enable_upgrade_legacy, llvm::SourceMgr* source_mgr,
|
||||
mlir::MLIRContext* context);
|
||||
|
||||
// Load Saved model (either v1 or v2) into MLIR.
|
||||
stream_executor::port::StatusOr<mlir::OwningModuleRef> ImportSavedModel(
|
||||
|
@ -1356,6 +1356,7 @@ cc_library(
|
||||
srcs = ["utils/tpu_rewrite_device_util.cc"],
|
||||
hdrs = ["utils/tpu_rewrite_device_util.h"],
|
||||
deps = [
|
||||
":tensorflow",
|
||||
"//tensorflow/compiler/xla:array4d",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:computation_placer",
|
||||
@ -1366,6 +1367,7 @@ cc_library(
|
||||
"@com_google_absl//absl/strings",
|
||||
"@llvm-project//llvm:Support",
|
||||
"@llvm-project//mlir:IR",
|
||||
"@llvm-project//mlir:Support",
|
||||
],
|
||||
)
|
||||
|
||||
@ -1374,6 +1376,7 @@ tf_cc_test(
|
||||
size = "small",
|
||||
srcs = ["utils/tpu_rewrite_device_util_test.cc"],
|
||||
deps = [
|
||||
":device_util",
|
||||
":tpu_rewrite_device_util",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:test",
|
||||
|
@ -4,8 +4,6 @@ load("//tensorflow/compiler/mlir:glob_lit_test.bzl", "lit_test")
|
||||
|
||||
def tf_saved_model_test(name, data, tags = None):
|
||||
"""Create a SavedModel test."""
|
||||
if tags == None:
|
||||
tags = ["no_rocm"]
|
||||
native.py_binary(
|
||||
name = name,
|
||||
testonly = 1,
|
||||
@ -26,5 +24,5 @@ def tf_saved_model_test(name, data, tags = None):
|
||||
name = name + ".py",
|
||||
data = [name] + data,
|
||||
driver = "@llvm-project//mlir:run_lit.sh",
|
||||
tags = tags,
|
||||
tags = tags + ["no_rocm"],
|
||||
)
|
||||
|
@ -113,64 +113,6 @@ tf_device::LaunchOp CreateLaunchForBlock(OpBuilder* builder, Operation* op,
|
||||
return launch;
|
||||
}
|
||||
|
||||
// Parses TPU compilation and execution devices from a TPU cluster and returns
|
||||
// the host device for the head and tail computations. If the TPU computation is
|
||||
// replicated, kTPUReplicatedHost is returned instead.
|
||||
LogicalResult GetHostDeviceForHeadTailComputation(
|
||||
mlir::TF::RuntimeDevices devices, tf_device::ClusterOp cluster,
|
||||
std::string* host_device) {
|
||||
auto replicate = cluster.getParentOfType<tf_device::ReplicateOp>();
|
||||
if (replicate) {
|
||||
*host_device = tensorflow::kTPUReplicatedHost;
|
||||
return success();
|
||||
}
|
||||
|
||||
auto num_cores_per_replica_attr =
|
||||
cluster.getAttrOfType<IntegerAttr>(tensorflow::kNumCoresPerReplicaAttr);
|
||||
if (!num_cores_per_replica_attr)
|
||||
return cluster.emitOpError(
|
||||
"cluster op missing `num_cores_per_replica` attribute");
|
||||
|
||||
if (num_cores_per_replica_attr.getInt() != 1)
|
||||
return cluster.emitOpError(
|
||||
"outside compilation is not supported with model parallelism.");
|
||||
|
||||
auto topology_attr =
|
||||
cluster.getAttrOfType<StringAttr>(tensorflow::kTopologyAttr);
|
||||
if (!topology_attr)
|
||||
return cluster.emitOpError("cluster op missing `topology` attribute");
|
||||
|
||||
auto device_assignment_attr =
|
||||
cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
|
||||
if (!device_assignment_attr)
|
||||
return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
|
||||
tensorflow::kDeviceAssignmentAttr)
|
||||
.str());
|
||||
|
||||
auto status_or_device_coodinates =
|
||||
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
||||
|
||||
if (!status_or_device_coodinates.ok())
|
||||
return cluster.emitError()
|
||||
<< "error in fetching tpu device coordinates: "
|
||||
<< status_or_device_coodinates.status().error_message();
|
||||
|
||||
// Determine compilation and execution devices.
|
||||
auto status_or_tpu_device_assignment =
|
||||
tensorflow::GetTPUCompilationAndExecutionDevices(
|
||||
devices.device_names(), /*num_replicas=*/1,
|
||||
/*num_cores_per_replica=*/1, topology_attr.getValue(),
|
||||
status_or_device_coodinates.ConsumeValueOrDie());
|
||||
if (!status_or_tpu_device_assignment.ok())
|
||||
return cluster.emitError()
|
||||
<< "error in fetching TPU compilation/execution devices: "
|
||||
<< status_or_tpu_device_assignment.status().error_message();
|
||||
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
||||
|
||||
*host_device = tpu_device_assignment.tpu_devices[0][0].host;
|
||||
return success();
|
||||
}
|
||||
|
||||
// Returns a set of ops that are outside compiled and can be extracted to before
|
||||
// the TPU computation. These ops are either connected to the inputs of the TPU
|
||||
// computation or other ops that can be extracted, and have no operands from
|
||||
@ -232,8 +174,8 @@ mlir::LogicalResult LiftHeadOutsideCompiledOps(
|
||||
llvm::SmallVector<Operation*, 4> head_outside_compiled_ops =
|
||||
FindOutsideCompiledOpsAtHead(cluster);
|
||||
if (head_outside_compiled_ops.empty()) return success();
|
||||
if (failed(
|
||||
GetHostDeviceForHeadTailComputation(devices, cluster, host_device)))
|
||||
if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, cluster,
|
||||
host_device)))
|
||||
return failure();
|
||||
|
||||
CreateHeadComputation(builder, cluster, head_outside_compiled_ops,
|
||||
@ -361,8 +303,8 @@ mlir::LogicalResult LiftTailOutsideCompiledOps(
|
||||
if (tail_outside_compiled_ops.empty()) return success();
|
||||
|
||||
if (host_device.empty())
|
||||
if (failed(GetHostDeviceForHeadTailComputation(devices, *cluster,
|
||||
&host_device)))
|
||||
if (failed(tensorflow::GetHostDeviceOutsideComputation(devices, *cluster,
|
||||
&host_device)))
|
||||
return failure();
|
||||
|
||||
// Forward all results of cluster first. These results will be remapped once
|
||||
|
@ -484,4 +484,59 @@ std::string GetDeviceAliasForLogicalCore(int core_index) {
|
||||
return llvm::formatv("{0}_{1}", kTPUReplicatedCore, core_index).str();
|
||||
}
|
||||
|
||||
mlir::LogicalResult GetHostDeviceOutsideComputation(
|
||||
mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
|
||||
std::string* host_device) {
|
||||
auto replicate = cluster.getParentOfType<mlir::tf_device::ReplicateOp>();
|
||||
if (replicate) {
|
||||
*host_device = tensorflow::kTPUReplicatedHost;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
auto num_cores_per_replica_attr = cluster.getAttrOfType<mlir::IntegerAttr>(
|
||||
tensorflow::kNumCoresPerReplicaAttr);
|
||||
if (!num_cores_per_replica_attr)
|
||||
return cluster.emitOpError(
|
||||
"cluster op missing `num_cores_per_replica` attribute");
|
||||
|
||||
if (num_cores_per_replica_attr.getInt() != 1)
|
||||
return cluster.emitOpError(
|
||||
"outside compilation is not supported with model parallelism.");
|
||||
|
||||
auto topology_attr =
|
||||
cluster.getAttrOfType<mlir::StringAttr>(tensorflow::kTopologyAttr);
|
||||
if (!topology_attr)
|
||||
return cluster.emitOpError("cluster op missing `topology` attribute");
|
||||
|
||||
auto device_assignment_attr =
|
||||
cluster.getAttrOfType<mlir::ArrayAttr>(tensorflow::kDeviceAssignmentAttr);
|
||||
if (!device_assignment_attr)
|
||||
return cluster.emitOpError(llvm::formatv("requires attribute '{0}'",
|
||||
tensorflow::kDeviceAssignmentAttr)
|
||||
.str());
|
||||
|
||||
auto status_or_device_coodinates =
|
||||
tensorflow::GetDeviceCoordinates(device_assignment_attr);
|
||||
|
||||
if (!status_or_device_coodinates.ok())
|
||||
return cluster.emitError()
|
||||
<< "error in fetching tpu device coordinates: "
|
||||
<< status_or_device_coodinates.status().error_message();
|
||||
|
||||
// Determine compilation and execution devices.
|
||||
auto status_or_tpu_device_assignment =
|
||||
tensorflow::GetTPUCompilationAndExecutionDevices(
|
||||
devices.device_names(), /*num_replicas=*/1,
|
||||
/*num_cores_per_replica=*/1, topology_attr.getValue(),
|
||||
status_or_device_coodinates.ConsumeValueOrDie());
|
||||
if (!status_or_tpu_device_assignment.ok())
|
||||
return cluster.emitError()
|
||||
<< "error in fetching TPU compilation/execution devices: "
|
||||
<< status_or_tpu_device_assignment.status().error_message();
|
||||
auto& tpu_device_assignment = status_or_tpu_device_assignment.ValueOrDie();
|
||||
|
||||
*host_device = tpu_device_assignment.tpu_devices[0][0].host;
|
||||
return mlir::success();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
@ -23,6 +23,9 @@ limitations under the License.
|
||||
#include "llvm/ADT/SmallVector.h"
|
||||
#include "llvm/ADT/StringRef.h"
|
||||
#include "mlir/IR/Attributes.h" // from @llvm-project
|
||||
#include "mlir/Support/LogicalResult.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_device.h"
|
||||
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_structs.h"
|
||||
#include "tensorflow/compiler/xla/xla_data.pb.h"
|
||||
#include "tensorflow/core/lib/core/status.h"
|
||||
#include "tensorflow/core/util/device_name_utils.h"
|
||||
@ -237,6 +240,13 @@ StatusOr<TPUDeviceAssignment> GetTPUCompilationAndExecutionDevices(
|
||||
// logical core.
|
||||
std::string GetDeviceAliasForLogicalCore(int core_index);
|
||||
|
||||
// Parses TPU compilation and execution devices from a TPU cluster and returns
|
||||
// the host device for the head and tail computations. If the TPU computation is
|
||||
// replicated, kTPUReplicatedHost is returned instead.
|
||||
mlir::LogicalResult GetHostDeviceOutsideComputation(
|
||||
mlir::TF::RuntimeDevices devices, mlir::tf_device::ClusterOp cluster,
|
||||
std::string* host_device);
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_COMPILER_MLIR_TENSORFLOW_UTILS_TPU_REWRITE_DEVICE_UTIL_H_
|
||||
|
@ -21,6 +21,8 @@ limitations under the License.
|
||||
#include "llvm/Support/FormatVariadic.h"
|
||||
#include "mlir/IR/Builders.h" // from @llvm-project
|
||||
#include "mlir/IR/MLIRContext.h" // from @llvm-project
|
||||
#include "mlir/IR/Module.h" // from @llvm-project
|
||||
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
|
||||
#include "tensorflow/core/lib/core/status_test_util.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/protobuf/tpu/topology.pb.h"
|
||||
@ -622,5 +624,185 @@ TEST(TPURewriteDeviceUtilTest, TestInvalidAttrForDeviceAssignmentDisallowed) {
|
||||
"bad 'device_assignment' attribute at index 0, not an int");
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostFailDeviceMissingAttributes) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
|
||||
mlir::TF::RuntimeDevices devices;
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::failed(
|
||||
GetHostDeviceOutsideComputation(devices, cluster, &host_device)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailModelParallelism) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 5));
|
||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::failed(
|
||||
GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingTopology) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::failed(
|
||||
GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailMissingDeviceAssignment) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::failed(
|
||||
GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceAssignment) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||
cluster.setAttr(kDeviceAssignmentAttr,
|
||||
builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
|
||||
{"bad_device_assigment"})));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::failed(
|
||||
GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceFailBadDeviceName) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
module_ref->setAttr(
|
||||
"tf.devices", builder.getStrArrayAttr(
|
||||
llvm::ArrayRef<llvm::StringRef>({"bad_device_name"})));
|
||||
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
GetDevicesFromOp(*module_ref, &runtime_devices);
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::failed(
|
||||
GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceTPUReplicate) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
|
||||
llvm::SmallDenseMap<llvm::StringRef, llvm::SmallVector<llvm::StringRef, 4>>
|
||||
devices;
|
||||
auto replicate = builder.create<mlir::tf_device::ReplicateOp>(
|
||||
mlir::UnknownLoc::get(&context), /*num_replicas=*/2, devices,
|
||||
llvm::ArrayRef<std::pair<llvm::ArrayRef<mlir::Value>, mlir::Type>>{},
|
||||
llvm::ArrayRef<mlir::Type>{});
|
||||
builder.setInsertionPoint(&replicate.body().front(),
|
||||
replicate.body().front().begin());
|
||||
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::succeeded(
|
||||
GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
|
||||
EXPECT_EQ(host_device, kTPUReplicatedHost);
|
||||
}
|
||||
|
||||
TEST(TPURewriteDeviceUtilTest, TestGetHostDeviceNotReplicated) {
|
||||
mlir::registerDialect<mlir::tf_device::TensorFlowDeviceDialect>();
|
||||
mlir::MLIRContext context;
|
||||
mlir::OwningModuleRef module_ref =
|
||||
mlir::ModuleOp::create(mlir::UnknownLoc::get(&context));
|
||||
mlir::OpBuilder builder(module_ref->getBodyRegion());
|
||||
module_ref->setAttr(
|
||||
"tf.devices", builder.getStrArrayAttr(llvm::ArrayRef<llvm::StringRef>(
|
||||
{"/job:localhost/replica:0/task:0/device:TPU_SYSTEM:0",
|
||||
"/job:localhost/replica:0/task:0/device:TPU:0",
|
||||
"/job:worker/replica:0/task:0/device:CPU:0"})));
|
||||
|
||||
llvm::SmallVector<mlir::Type, 8> result_types;
|
||||
auto cluster = builder.create<mlir::tf_device::ClusterOp>(
|
||||
mlir::UnknownLoc::get(&context), result_types);
|
||||
cluster.setAttr(kNumCoresPerReplicaAttr,
|
||||
builder.getIntegerAttr(builder.getIntegerType(64), 1));
|
||||
cluster.setAttr(kTopologyAttr, builder.getStringAttr(""));
|
||||
cluster.setAttr(kDeviceAssignmentAttr, builder.getArrayAttr({}));
|
||||
|
||||
mlir::TF::RuntimeDevices runtime_devices;
|
||||
GetDevicesFromOp(*module_ref, &runtime_devices);
|
||||
std::string host_device;
|
||||
EXPECT_TRUE(mlir::succeeded(
|
||||
GetHostDeviceOutsideComputation(runtime_devices, cluster, &host_device)));
|
||||
EXPECT_EQ(host_device, "/job:localhost/replica:0/task:0/device:CPU:0");
|
||||
}
|
||||
|
||||
} // anonymous namespace
|
||||
} // namespace tensorflow
|
||||
|
@ -129,7 +129,7 @@ struct DynamicMemRefCastOpConverter
|
||||
void PopulateLhloToLLVMConversionPatterns(LLVMTypeConverter *converter,
|
||||
OwningRewritePatternList *patterns) {
|
||||
patterns->insert<DynamicMemRefCastOpConverter, StaticMemRefCastOpConverter>(
|
||||
*converter, LowerToLLVMOptions());
|
||||
*converter);
|
||||
}
|
||||
|
||||
} // namespace xla_lhlo
|
||||
|
@ -229,16 +229,16 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
self._testBinary(
|
||||
gen_math_ops.xdivy,
|
||||
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
|
||||
np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
|
||||
expected=np.array([0, 0.8, 0.5, 0.285714, 0.125, 0], dtype=dtype),
|
||||
np.array([[0, 5, 6, 7, 8, float("NaN")]], dtype=dtype),
|
||||
expected=np.array([[0, 0.8, 0.5, 0.285714, 0.125, 0]], dtype=dtype),
|
||||
rtol=1e-6,
|
||||
atol=1e-6)
|
||||
|
||||
self._testBinary(
|
||||
gen_math_ops.xlogy,
|
||||
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
|
||||
np.array([0, 5, 6, 7, 8, float("NaN")], dtype=dtype),
|
||||
expected=np.array([0, 6.437752, 5.375278, 3.89182, 2.079442, 0],
|
||||
np.array([[0, 5, 6, 7, 8, float("NaN")]], dtype=dtype),
|
||||
expected=np.array([[0, 6.437752, 5.375278, 3.89182, 2.079442, 0]],
|
||||
dtype=dtype),
|
||||
rtol=1e-4,
|
||||
atol=1e-6)
|
||||
@ -246,8 +246,8 @@ class BinaryOpsTest(xla_test.XLATestCase):
|
||||
self._testBinary(
|
||||
gen_math_ops.xlog1py,
|
||||
np.array([0, 4, 3, 2, 1, 0], dtype=dtype),
|
||||
np.array([-1, 5, 6, 7, 8, float("NaN")], dtype=dtype),
|
||||
expected=np.array([0, 7.167038, 5.837730, 4.158883, 2.197225, 0],
|
||||
np.array([[-1, 5, 6, 7, 8, float("NaN")]], dtype=dtype),
|
||||
expected=np.array([[0, 7.167038, 5.837730, 4.158883, 2.197225, 0]],
|
||||
dtype=dtype),
|
||||
rtol=1e-4,
|
||||
atol=1e-6)
|
||||
|
@ -153,6 +153,7 @@ XLA_MAKE_BINARY(Xlogy, XlogyImpl(lhs, rhs, broadcast_helper));
|
||||
|
||||
xla::XlaOp Xlog1pyImpl(xla::XlaOp x, xla::XlaOp y,
|
||||
const BCast& broadcast_helper) {
|
||||
std::tie(x, y) = XlaBinaryOp::Broadcast(x, y, broadcast_helper);
|
||||
auto non_zero = xla::Mul(x, xla::Log1p(y));
|
||||
auto zero = xla::ZerosLike(non_zero);
|
||||
auto x_is_zero = xla::Eq(x, zero);
|
||||
|
@ -24,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/compiler/xla/client/lib/constants.h"
|
||||
#include "tensorflow/compiler/xla/client/lib/math.h"
|
||||
#include "tensorflow/compiler/xla/client/xla_builder.h"
|
||||
#include "tensorflow/compiler/xla/primitive_util.h"
|
||||
#include "tensorflow/core/framework/kernel_def_builder.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -85,8 +86,20 @@ XLAJIT_MAKE_UNARY(Rsqrt, xla::Rsqrt(x));
|
||||
XLAJIT_MAKE_UNARY(Sigmoid, xla::Logistic(x));
|
||||
|
||||
// Returns 0 if x is NaN, 0 if x is 0, -1 if x < 0 and 1 if x > 0.
|
||||
XLAJIT_MAKE_UNARY(Sign,
|
||||
xla::Select(xla::Ne(x, x), xla::ZerosLike(x), xla::Sign(x)));
|
||||
static xla::XlaOp Sign(xla::XlaBuilder* b, xla::XlaOp x) {
|
||||
return b->ReportErrorOrReturn([&]() -> xla::StatusOr<xla::XlaOp> {
|
||||
TF_ASSIGN_OR_RETURN(auto shape, b->GetShape(x));
|
||||
if (xla::primitive_util::IsComplexType(shape.element_type())) {
|
||||
return xla::Sign(x);
|
||||
}
|
||||
auto gt = xla::Gt(x, xla::ZerosLike(x));
|
||||
auto lt = xla::Lt(x, xla::ZerosLike(x));
|
||||
return xla::ConvertElementType(gt, shape.element_type()) -
|
||||
xla::ConvertElementType(lt, shape.element_type());
|
||||
});
|
||||
}
|
||||
|
||||
XLAJIT_MAKE_UNARY(Sign, Sign(b, x));
|
||||
XLAJIT_MAKE_UNARY(Sinh, xla::Sinh(x));
|
||||
|
||||
static xla::XlaOp Softplus(xla::XlaBuilder* b, xla::XlaOp features) {
|
||||
|
@ -3,9 +3,10 @@
|
||||
Caution: Tiled layout is *pre-release* and this describes how it's intended to
|
||||
work. Errors may be silently ignored.
|
||||
|
||||
<center> 
|
||||
|
||||
Figure 1 </center>
|
||||
<p align="center">
|
||||
<img src="images/xla_array_layout_figure1.png">
|
||||
Figure 1
|
||||
</p>
|
||||
|
||||
Figure 1 shows how an array F32[3,5] is laid out in memory with 2x2 tiling. A
|
||||
shape with this layout is written as F32[3,5]{1,0:(2,2)}, where 1,0 relates to
|
||||
@ -120,9 +121,10 @@ element follows the formula above as expected.
|
||||
|
||||
XLA's tiling becomes even more flexible by applying it repeatedly.
|
||||
|
||||
<center> 
|
||||
|
||||
Figure 2 </center>
|
||||
<p align="center">
|
||||
<img src="images/xla_array_layout_figure2.png">
|
||||
Figure 2
|
||||
</p>
|
||||
|
||||
Figure 2 shows how an array of size 4x8 is tiled by two levels of tiling (first
|
||||
2x4 then 2x1). We represent this repeated tiling as (2,4)(2,1). Each color
|
||||
|
@ -1202,6 +1202,9 @@ cc_library(
|
||||
srcs = ["transfer_manager.cc"],
|
||||
hdrs = ["transfer_manager.h"],
|
||||
deps = [
|
||||
":compiler",
|
||||
":executable",
|
||||
":maybe_owning_device_memory",
|
||||
":shaped_buffer",
|
||||
"//tensorflow/compiler/xla:literal",
|
||||
"//tensorflow/compiler/xla:shape_util",
|
||||
@ -1210,8 +1213,6 @@ cc_library(
|
||||
"//tensorflow/compiler/xla:types",
|
||||
"//tensorflow/compiler/xla:util",
|
||||
"//tensorflow/compiler/xla:xla_data_proto_cc",
|
||||
"//tensorflow/compiler/xla/service:executable",
|
||||
"//tensorflow/compiler/xla/service:maybe_owning_device_memory",
|
||||
"//tensorflow/core:lib",
|
||||
"//tensorflow/core:stream_executor_no_cuda",
|
||||
"//tensorflow/stream_executor:device_memory",
|
||||
|
@ -20,6 +20,7 @@ limitations under the License.
|
||||
|
||||
#include "absl/memory/memory.h"
|
||||
#include "absl/strings/str_cat.h"
|
||||
#include "tensorflow/compiler/xla/service/compiler.h"
|
||||
#include "tensorflow/compiler/xla/service/maybe_owning_device_memory.h"
|
||||
#include "tensorflow/compiler/xla/shape_util.h"
|
||||
#include "tensorflow/compiler/xla/status_macros.h"
|
||||
@ -33,6 +34,7 @@ limitations under the License.
|
||||
using absl::StrCat;
|
||||
|
||||
namespace xla {
|
||||
|
||||
/* static */ tensorflow::mutex
|
||||
TransferManager::platform_transfer_manager_mutex_(
|
||||
tensorflow::LINKER_INITIALIZED);
|
||||
@ -200,6 +202,67 @@ void TransferManager::TransferArrayFromDevice(
|
||||
std::move(done), transfer_metadata);
|
||||
}
|
||||
|
||||
Status TransferManager::ReadDynamicShapes(se::Stream* stream,
|
||||
ShapedBuffer* device_buffer,
|
||||
Shape* host_shape,
|
||||
Shape* device_shape) {
|
||||
DCHECK(device_shape->is_dynamic());
|
||||
Shape original_device_shape = *device_shape;
|
||||
Shape original_host_shape = *host_shape;
|
||||
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
|
||||
|
||||
TF_ASSIGN_OR_RETURN(auto compiler,
|
||||
Compiler::GetForPlatform(stream->parent()->platform()));
|
||||
TF_RETURN_IF_ERROR(device_buffer->buffers().ForEachMutableElementWithStatus(
|
||||
[&](const ShapeIndex& index, se::DeviceMemoryBase* buffer) {
|
||||
const Shape& buffer_shape =
|
||||
ShapeUtil::GetSubshape(*device_shape, index);
|
||||
if (buffer_shape.IsTuple()) {
|
||||
return Status::OK();
|
||||
}
|
||||
Shape& host_sub_shape =
|
||||
*ShapeUtil::GetMutableSubshape(host_shape, index);
|
||||
Shape& device_sub_shape =
|
||||
*ShapeUtil::GetMutableSubshape(device_shape, index);
|
||||
if (device_sub_shape.is_static()) {
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
// Read the dynamic shape metadata from the device stream.
|
||||
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
|
||||
Shape buffer_shape_static = ShapeUtil::MakeStaticShape(buffer_shape);
|
||||
const int64 offset = shape_size_fn(buffer_shape_static);
|
||||
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
|
||||
if (metadata_size == 0) {
|
||||
return InvalidArgument("Dynamic shape metadata size should not be 0");
|
||||
}
|
||||
auto buffer_8 = se::DeviceMemory<uint8>(*buffer);
|
||||
auto metadata_buffer =
|
||||
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto metadata,
|
||||
TransferArrayFromDevice(
|
||||
stream,
|
||||
ShapeUtil::MakeShape(S32, {buffer_shape.dimensions_size()}),
|
||||
metadata_buffer));
|
||||
|
||||
// Update shape size from metadata.
|
||||
for (int64 i = 0; i < metadata.element_count(); ++i) {
|
||||
host_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
device_sub_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
host_shape->clear_dynamic_dimensions();
|
||||
device_shape->clear_dynamic_dimensions();
|
||||
|
||||
TF_RET_CHECK(ShapeUtil::DynamicShapeIsCompatible(*device_shape,
|
||||
original_device_shape));
|
||||
TF_RET_CHECK(
|
||||
ShapeUtil::DynamicShapeIsCompatible(*host_shape, original_host_shape));
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
/* static */ void TransferManager::RegisterTransferManager(
|
||||
se::Platform::Id platform_id,
|
||||
TransferManagerCreationFunction creation_function) {
|
||||
|
@ -184,6 +184,15 @@ class TransferManager {
|
||||
const se::DeviceMemoryBase& source,
|
||||
const TransferMetadata* transfer_metadata = nullptr);
|
||||
|
||||
// Read from a device buffer and update the dynamic dimension sizes of
|
||||
// `host_shape` and `device_shape`. The function takes in bounded dynamic
|
||||
// shapes, and returns static shapes with dynamic shapes updated.
|
||||
// The shape of the buffer also have to be compatible with the host shape and
|
||||
// device shape.
|
||||
virtual Status ReadDynamicShapes(se::Stream* stream,
|
||||
ShapedBuffer* device_buffer,
|
||||
Shape* host_shape, Shape* device_shape);
|
||||
|
||||
// Transfers the given literal into the Infeed interface of the device,
|
||||
// using the given executor.
|
||||
virtual Status TransferLiteralToInfeed(se::StreamExecutor* executor,
|
||||
|
@ -120,7 +120,7 @@ enum Format {
|
||||
}
|
||||
|
||||
// Describes a tile used in tiling-based layout. Refer to
|
||||
// g3doc/third_party/tensorflow/compiler/xla/g3doc/layout_with_tiling.md for
|
||||
// g3doc/third_party/tensorflow/compiler/xla/g3doc/tiled_layout.md for
|
||||
// details about tiling-based layout.
|
||||
message TileProto {
|
||||
// Number of elements in each dimension of the tile. It's ordered from the
|
||||
|
@ -264,86 +264,28 @@ Status UpdateDynamicInputs(
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<xla::Literal> ReadMetadataLiteral(
|
||||
se::Stream* stream, se::DeviceMemoryBase buffer,
|
||||
const xla::Shape& buffer_shape, xla::TransferManager* transfer_manager) {
|
||||
TF_ASSIGN_OR_RETURN(auto compiler, xla::Compiler::GetForPlatform(
|
||||
stream->parent()->platform()));
|
||||
auto shape_size_fn = compiler->ShapeSizeBytesFunction();
|
||||
xla::Shape buffer_shape_static =
|
||||
xla::ShapeUtil::MakeStaticShape(buffer_shape);
|
||||
const int64 offset = shape_size_fn(buffer_shape_static);
|
||||
int64 metadata_size = shape_size_fn(buffer_shape) - offset;
|
||||
TF_RET_CHECK(metadata_size != 0);
|
||||
auto buffer_8 = se::DeviceMemory<uint8>(buffer);
|
||||
auto metadata_buffer =
|
||||
stream->parent()->GetSubBuffer(&buffer_8, offset, metadata_size);
|
||||
return transfer_manager->TransferArrayFromDevice(
|
||||
stream,
|
||||
xla::ShapeUtil::MakeShape(xla::S32, {buffer_shape.dimensions_size()}),
|
||||
metadata_buffer);
|
||||
}
|
||||
|
||||
// For each subshape in the result buffer that's dynamic, read the dynamic
|
||||
// dimension sizes from the metadata, and update output shapes. The result shape
|
||||
// is a static and concrete shape.
|
||||
xla::Status UpdateDynamicOutputs(se::Stream* stream,
|
||||
const xla::ShapedBuffer& shaped_buffer,
|
||||
xla::Shape* output_host_shape,
|
||||
xla::Shape* output_device_shape) {
|
||||
DCHECK(output_device_shape->is_dynamic());
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
|
||||
TF_RETURN_IF_ERROR(shaped_buffer.buffers().ForEachElementWithStatus(
|
||||
[&](const xla::ShapeIndex& index, const se::DeviceMemoryBase& buffer) {
|
||||
const xla::Shape& buffer_shape =
|
||||
xla::ShapeUtil::GetSubshape(*output_device_shape, index);
|
||||
if (buffer_shape.IsTuple()) {
|
||||
return Status::OK();
|
||||
}
|
||||
xla::Shape& host_shape =
|
||||
*xla::ShapeUtil::GetMutableSubshape(output_host_shape, index);
|
||||
xla::Shape& device_shape =
|
||||
*xla::ShapeUtil::GetMutableSubshape(output_device_shape, index);
|
||||
if (device_shape.is_static()) {
|
||||
return Status::OK();
|
||||
}
|
||||
TF_ASSIGN_OR_RETURN(auto metadata,
|
||||
ReadMetadataLiteral(stream, buffer, buffer_shape,
|
||||
transfer_manager));
|
||||
// Update shape size from metadata.
|
||||
for (int64 i = 0; i < metadata.element_count(); ++i) {
|
||||
host_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
device_shape.mutable_dimensions()[i] = metadata.Get<int32>({i});
|
||||
}
|
||||
return Status::OK();
|
||||
}));
|
||||
output_host_shape->clear_dynamic_dimensions();
|
||||
output_device_shape->clear_dynamic_dimensions();
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
xla::StatusOr<RefPtr<XRTTupleAllocation>> CreateOutputTuple(
|
||||
se::Stream* stream, xla::ExecutionOutput run_result, xla::Backend* backend,
|
||||
int device_ordinal) {
|
||||
XRTTupleAllocation* output_tuple;
|
||||
const xla::ScopedShapedBuffer& shaped_buffer = run_result.Result();
|
||||
if (shaped_buffer.on_device_shape().is_dynamic()) {
|
||||
xla::ScopedShapedBuffer* shaped_buffer = run_result.MutableResult();
|
||||
if (shaped_buffer->on_device_shape().is_dynamic()) {
|
||||
// Update dynamic shapes from output buffer, and create a XRT tensor with
|
||||
// dimension sizes read from metadata.
|
||||
xla::Shape output_host_shape = shaped_buffer.on_host_shape();
|
||||
xla::Shape output_device_shape = shaped_buffer.on_device_shape();
|
||||
TF_RETURN_IF_ERROR(UpdateDynamicOutputs(
|
||||
xla::Shape output_host_shape = shaped_buffer->on_host_shape();
|
||||
xla::Shape output_device_shape = shaped_buffer->on_device_shape();
|
||||
TF_ASSIGN_OR_RETURN(
|
||||
auto transfer_manager,
|
||||
xla::TransferManager::GetForPlatform(stream->parent()->platform()));
|
||||
TF_RETURN_IF_ERROR(transfer_manager->ReadDynamicShapes(
|
||||
stream, shaped_buffer, &output_host_shape, &output_device_shape));
|
||||
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
||||
shaped_buffer, output_host_shape, output_device_shape, backend,
|
||||
*shaped_buffer, output_host_shape, output_device_shape, backend,
|
||||
device_ordinal, &output_tuple));
|
||||
} else {
|
||||
// Fast-path: Don't copy shapes of output buffer.
|
||||
TF_RETURN_IF_ERROR(XRTTupleAllocation::CreateFromBuffer(
|
||||
shaped_buffer, backend, device_ordinal, &output_tuple));
|
||||
*shaped_buffer, backend, device_ordinal, &output_tuple));
|
||||
}
|
||||
// After the output tuple is created, we can release the output result
|
||||
// buffers, to make sure they won't be cleared by its destructor.
|
||||
|
@ -197,7 +197,7 @@ Status EagerOperation::Execute(absl::Span<AbstractTensorHandle*> retvals,
|
||||
if (device == kVariantDeviceNull) {
|
||||
bool pin_to_cpu;
|
||||
TF_RETURN_IF_ERROR(eager::MaybePinSmallOpsToCpu(
|
||||
&pin_to_cpu, op_name(),
|
||||
&pin_to_cpu, Name(),
|
||||
absl::MakeSpan(
|
||||
reinterpret_cast<ImmediateExecutionTensorHandle**>(inputs_.data()),
|
||||
inputs_.size()),
|
||||
|
@ -764,8 +764,8 @@ Tensor GetResourceHandle(const string& var_name, const string& container,
|
||||
handle.set_device(device_name);
|
||||
handle.set_container(container);
|
||||
handle.set_name(var_name);
|
||||
handle.set_hash_code(MakeTypeIndex<Var>().hash_code());
|
||||
handle.set_maybe_type_name(MakeTypeIndex<Var>().name());
|
||||
handle.set_hash_code(TypeIndex::Make<Var>().hash_code());
|
||||
handle.set_maybe_type_name(TypeIndex::Make<Var>().name());
|
||||
Tensor tensor(DT_RESOURCE, TensorShape({}));
|
||||
tensor.scalar<ResourceHandle>()() = handle;
|
||||
return tensor;
|
||||
|
@ -301,7 +301,7 @@ ResourceHandle MakeResourceHandle(
|
||||
return MakeResourceHandle(
|
||||
container.empty() ? ctx->resource_manager()->default_container()
|
||||
: container,
|
||||
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
|
||||
name, *ctx->device(), TypeIndex::Make<T>(), dtypes_and_shapes);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -311,7 +311,7 @@ ResourceHandle MakeResourceHandle(
|
||||
return MakeResourceHandle(
|
||||
container.empty() ? ctx->resource_manager()->default_container()
|
||||
: container,
|
||||
name, *ctx->device(), MakeTypeIndex<T>(), dtypes_and_shapes);
|
||||
name, *ctx->device(), TypeIndex::Make<T>(), dtypes_and_shapes);
|
||||
}
|
||||
|
||||
Status MakeResourceHandleToOutput(OpKernelContext* context, int output_index,
|
||||
@ -589,7 +589,7 @@ Status ResourceMgr::Create(const string& container, const string& name,
|
||||
CheckDeriveFromResourceBase<T>();
|
||||
CHECK(resource != nullptr);
|
||||
mutex_lock l(mu_);
|
||||
return DoCreate(container, MakeTypeIndex<T>(), name, resource);
|
||||
return DoCreate(container, TypeIndex::Make<T>(), name, resource);
|
||||
}
|
||||
|
||||
template <typename T, bool use_dynamic_cast>
|
||||
@ -635,7 +635,7 @@ template <typename T, bool use_dynamic_cast>
|
||||
Status ResourceMgr::LookupInternal(const string& container, const string& name,
|
||||
T** resource) const {
|
||||
ResourceBase* found = nullptr;
|
||||
Status s = DoLookup(container, MakeTypeIndex<T>(), name, &found);
|
||||
Status s = DoLookup(container, TypeIndex::Make<T>(), name, &found);
|
||||
if (s.ok()) {
|
||||
// It's safe to down cast 'found' to T* since
|
||||
// typeid(T).hash_code() is part of the map key.
|
||||
@ -660,7 +660,7 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
|
||||
s = LookupInternal<T, use_dynamic_cast>(container, name, resource);
|
||||
if (s.ok()) return s;
|
||||
TF_RETURN_IF_ERROR(creator(resource));
|
||||
s = DoCreate(container, MakeTypeIndex<T>(), name, *resource);
|
||||
s = DoCreate(container, TypeIndex::Make<T>(), name, *resource);
|
||||
if (!s.ok()) {
|
||||
return errors::Internal("LookupOrCreate failed unexpectedly");
|
||||
}
|
||||
@ -671,7 +671,7 @@ Status ResourceMgr::LookupOrCreate(const string& container, const string& name,
|
||||
template <typename T>
|
||||
Status ResourceMgr::Delete(const string& container, const string& name) {
|
||||
CheckDeriveFromResourceBase<T>();
|
||||
return DoDelete(container, MakeTypeIndex<T>(), name);
|
||||
return DoDelete(container, TypeIndex::Make<T>(), name);
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
@ -710,7 +710,7 @@ Status ValidateDevice(OpKernelContext* ctx, const ResourceHandle& p);
|
||||
template <typename T>
|
||||
Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
|
||||
TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
|
||||
auto type_index = MakeTypeIndex<T>();
|
||||
auto type_index = TypeIndex::Make<T>();
|
||||
if (type_index.hash_code() != p.hash_code()) {
|
||||
return errors::InvalidArgument(
|
||||
"Trying to access resource using the wrong type. Expected ",
|
||||
@ -883,7 +883,7 @@ ResourceHandle ScopedStepContainer::MakeResourceHandle(
|
||||
mutex_lock ml(mu_);
|
||||
dirty_ = true;
|
||||
return tensorflow::MakeResourceHandle(container_, name, device,
|
||||
MakeTypeIndex<T>(), {});
|
||||
TypeIndex::Make<T>(), {});
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -105,7 +105,7 @@ class ResourceOpKernel : public OpKernel {
|
||||
if (has_resource_type_) {
|
||||
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
|
||||
context, 0, cinfo_.container(), cinfo_.name(),
|
||||
MakeTypeIndex<T>()));
|
||||
TypeIndex::Make<T>()));
|
||||
} else {
|
||||
context->set_output_ref(0, &mu_, handle_.AccessTensor(context));
|
||||
}
|
||||
|
@ -42,11 +42,15 @@ void ExpectClose(const Tensor& x, const Tensor& y, double atol, double rtol) {
|
||||
<< "typed_atol is negative: " << typed_atol;
|
||||
ASSERT_GE(typed_rtol, static_cast<RealType>(0.0))
|
||||
<< "typed_rtol is negative: " << typed_rtol;
|
||||
const int max_failures = 10;
|
||||
int num_failures = 0;
|
||||
for (int i = 0; i < size; ++i) {
|
||||
EXPECT_TRUE(
|
||||
internal::Helper<T>::IsClose(Tx[i], Ty[i], typed_atol, typed_rtol))
|
||||
<< "index = " << i << " x = " << Tx[i] << " y = " << Ty[i]
|
||||
<< " typed_atol = " << typed_atol << " typed_rtol = " << typed_rtol;
|
||||
<< "index = " << (++num_failures, i) << " x = " << Tx[i]
|
||||
<< " y = " << Ty[i] << " typed_atol = " << typed_atol
|
||||
<< " typed_rtol = " << typed_rtol;
|
||||
ASSERT_LT(num_failures, max_failures) << "Too many mismatches, giving up.";
|
||||
}
|
||||
}
|
||||
|
||||
|
@ -95,11 +95,6 @@ class TypeIndex {
|
||||
const char* name_;
|
||||
};
|
||||
|
||||
template <typename T>
|
||||
inline TypeIndex MakeTypeIndex() {
|
||||
return TypeIndex::Make<T>();
|
||||
}
|
||||
|
||||
} // namespace tensorflow
|
||||
|
||||
#endif // TENSORFLOW_CORE_FRAMEWORK_TYPE_INDEX_H_
|
||||
|
@ -144,7 +144,7 @@ void EncodeVariant(const T& value, string* buf);
|
||||
// Variant y_type_unknown = serialized_proto_f; // Store serialized Variant.
|
||||
//
|
||||
// EXPECT_EQ(x.TypeName(), y_type_unknown.TypeName()); // Looks like Foo.
|
||||
// EXPECT_EQ(MakeTypeIndex<VariantTensorDataProto>(),
|
||||
// EXPECT_EQ(TypeIndex::Make<VariantTensorDataProto>(),
|
||||
// y_type_unknown.TypeId());
|
||||
//
|
||||
class Variant {
|
||||
@ -227,7 +227,7 @@ class Variant {
|
||||
// of the original type when a TensorValueDataProto is stored as the
|
||||
// value. In this case, it returns the TypeIndex of TensorValueDataProto.
|
||||
TypeIndex TypeId() const {
|
||||
const TypeIndex VoidTypeIndex = MakeTypeIndex<void>();
|
||||
const TypeIndex VoidTypeIndex = TypeIndex::Make<void>();
|
||||
if (is_empty()) {
|
||||
return VoidTypeIndex;
|
||||
}
|
||||
@ -244,7 +244,7 @@ class Variant {
|
||||
// otherwise.
|
||||
template <typename T>
|
||||
T* get() {
|
||||
const TypeIndex TTypeIndex = MakeTypeIndex<T>();
|
||||
const TypeIndex TTypeIndex = TypeIndex::Make<T>();
|
||||
if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
|
||||
return std::addressof(static_cast<Variant::Value<T>*>(GetValue())->value);
|
||||
}
|
||||
@ -253,7 +253,7 @@ class Variant {
|
||||
// otherwise.
|
||||
template <typename T>
|
||||
const T* get() const {
|
||||
const TypeIndex TTypeIndex = MakeTypeIndex<T>();
|
||||
const TypeIndex TTypeIndex = TypeIndex::Make<T>();
|
||||
if (is_empty() || (TTypeIndex != TypeId())) return nullptr;
|
||||
return std::addressof(
|
||||
static_cast<const Variant::Value<T>*>(GetValue())->value);
|
||||
@ -333,7 +333,7 @@ class Variant {
|
||||
|
||||
TypeIndex TypeId() const final {
|
||||
const TypeIndex value_type_index =
|
||||
MakeTypeIndex<typename std::decay<T>::type>();
|
||||
TypeIndex::Make<typename std::decay<T>::type>();
|
||||
return value_type_index;
|
||||
}
|
||||
|
||||
|
@ -160,7 +160,7 @@ string TypeNameVariantImpl(
|
||||
const T& value,
|
||||
TypeNameResolver<T, false /* has_type_name */, false /* Tensor */,
|
||||
false /* protobuf */>) {
|
||||
return port::MaybeAbiDemangle(MakeTypeIndex<T>().name());
|
||||
return port::MaybeAbiDemangle(TypeIndex::Make<T>().name());
|
||||
}
|
||||
|
||||
template <typename T>
|
||||
|
@ -521,7 +521,7 @@ class UnaryVariantBinaryOpRegistration {
|
||||
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION(T, direction, \
|
||||
device_copy_fn) \
|
||||
INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
|
||||
__COUNTER__, T, direction, MakeTypeIndex<T>(), device_copy_fn)
|
||||
__COUNTER__, T, direction, TypeIndex::Make<T>(), device_copy_fn)
|
||||
|
||||
#define INTERNAL_REGISTER_UNARY_VARIANT_DEVICE_COPY_FUNCTION_UNIQ_HELPER( \
|
||||
ctr, T, direction, type_index, device_copy_fn) \
|
||||
@ -542,7 +542,7 @@ class UnaryVariantBinaryOpRegistration {
|
||||
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION(op, device, T, \
|
||||
unary_op_function) \
|
||||
REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||
__COUNTER__, op, device, T, MakeTypeIndex<T>(), unary_op_function)
|
||||
__COUNTER__, op, device, T, TypeIndex::Make<T>(), unary_op_function)
|
||||
|
||||
#define REGISTER_UNARY_VARIANT_UNARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||
ctr, op, device, T, type_index, unary_op_function) \
|
||||
@ -563,7 +563,7 @@ class UnaryVariantBinaryOpRegistration {
|
||||
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION(op, device, T, \
|
||||
binary_op_function) \
|
||||
REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||
__COUNTER__, op, device, T, MakeTypeIndex<T>(), binary_op_function)
|
||||
__COUNTER__, op, device, T, TypeIndex::Make<T>(), binary_op_function)
|
||||
|
||||
#define REGISTER_UNARY_VARIANT_BINARY_OP_FUNCTION_UNIQ_HELPER( \
|
||||
ctr, op, device, T, type_index, binary_op_function) \
|
||||
|
@ -155,12 +155,12 @@ TEST(VariantOpCopyToGPURegistryTest, TestBasic) {
|
||||
// No registered copy fn for GPU<->GPU.
|
||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
|
||||
VariantDeviceCopyDirection::DEVICE_TO_DEVICE,
|
||||
MakeTypeIndex<VariantValue>()),
|
||||
TypeIndex::Make<VariantValue>()),
|
||||
nullptr);
|
||||
|
||||
auto* copy_to_gpu_fn = UnaryVariantOpRegistry::Global()->GetDeviceCopyFn(
|
||||
VariantDeviceCopyDirection::HOST_TO_DEVICE,
|
||||
MakeTypeIndex<VariantValue>());
|
||||
TypeIndex::Make<VariantValue>());
|
||||
EXPECT_NE(copy_to_gpu_fn, nullptr);
|
||||
|
||||
VariantValue vv{true /* early_exit */};
|
||||
@ -183,7 +183,7 @@ TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
|
||||
UnaryVariantOpRegistry registry;
|
||||
UnaryVariantOpRegistry::AsyncVariantDeviceCopyFn f;
|
||||
class FjFjFj {};
|
||||
const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
|
||||
const auto kTypeIndex = TypeIndex::Make<FjFjFj>();
|
||||
registry.RegisterDeviceCopyFn(VariantDeviceCopyDirection::HOST_TO_DEVICE,
|
||||
kTypeIndex, f);
|
||||
EXPECT_DEATH(registry.RegisterDeviceCopyFn(
|
||||
@ -193,9 +193,10 @@ TEST(VariantOpCopyToGPURegistryTest, TestDuplicate) {
|
||||
|
||||
TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
|
||||
class Blah {};
|
||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
|
||||
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
|
||||
nullptr);
|
||||
EXPECT_EQ(
|
||||
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
|
||||
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU, TypeIndex::Make<Blah>()),
|
||||
nullptr);
|
||||
|
||||
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
|
||||
Variant v = vv_early_exit;
|
||||
@ -218,9 +219,10 @@ TEST(VariantOpZerosLikeRegistryTest, TestBasicCPU) {
|
||||
#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
|
||||
TEST(VariantOpUnaryOpRegistryTest, TestBasicGPU) {
|
||||
class Blah {};
|
||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
|
||||
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
|
||||
nullptr);
|
||||
EXPECT_EQ(
|
||||
UnaryVariantOpRegistry::Global()->GetUnaryOpFn(
|
||||
ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_GPU, TypeIndex::Make<Blah>()),
|
||||
nullptr);
|
||||
|
||||
VariantValue vv_early_exit{true /* early_exit */, 0 /* value */};
|
||||
Variant v = vv_early_exit;
|
||||
@ -245,7 +247,7 @@ TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
|
||||
UnaryVariantOpRegistry registry;
|
||||
UnaryVariantOpRegistry::VariantUnaryOpFn f;
|
||||
class FjFjFj {};
|
||||
const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
|
||||
const auto kTypeIndex = TypeIndex::Make<FjFjFj>();
|
||||
|
||||
registry.RegisterUnaryOpFn(ZEROS_LIKE_VARIANT_UNARY_OP, DEVICE_CPU,
|
||||
kTypeIndex, f);
|
||||
@ -263,7 +265,7 @@ TEST(VariantOpUnaryOpRegistryTest, TestDuplicate) {
|
||||
TEST(VariantOpAddRegistryTest, TestBasicCPU) {
|
||||
class Blah {};
|
||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
|
||||
ADD_VARIANT_BINARY_OP, DEVICE_CPU, MakeTypeIndex<Blah>()),
|
||||
ADD_VARIANT_BINARY_OP, DEVICE_CPU, TypeIndex::Make<Blah>()),
|
||||
nullptr);
|
||||
|
||||
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
|
||||
@ -290,7 +292,7 @@ TEST(VariantOpAddRegistryTest, TestBasicCPU) {
|
||||
TEST(VariantOpAddRegistryTest, TestBasicGPU) {
|
||||
class Blah {};
|
||||
EXPECT_EQ(UnaryVariantOpRegistry::Global()->GetBinaryOpFn(
|
||||
ADD_VARIANT_BINARY_OP, DEVICE_GPU, MakeTypeIndex<Blah>()),
|
||||
ADD_VARIANT_BINARY_OP, DEVICE_GPU, TypeIndex::Make<Blah>()),
|
||||
nullptr);
|
||||
|
||||
VariantValue vv_early_exit{true /* early_exit */, 3 /* value */};
|
||||
@ -318,7 +320,7 @@ TEST(VariantOpAddRegistryTest, TestDuplicate) {
|
||||
UnaryVariantOpRegistry registry;
|
||||
UnaryVariantOpRegistry::VariantBinaryOpFn f;
|
||||
class FjFjFj {};
|
||||
const auto kTypeIndex = MakeTypeIndex<FjFjFj>();
|
||||
const auto kTypeIndex = TypeIndex::Make<FjFjFj>();
|
||||
|
||||
registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU, kTypeIndex, f);
|
||||
EXPECT_DEATH(registry.RegisterBinaryOpFn(ADD_VARIANT_BINARY_OP, DEVICE_CPU,
|
||||
|
@ -589,7 +589,7 @@ TEST(VariantTest, TensorListTest) {
|
||||
serialized.ToProto(&data);
|
||||
const Variant y_unknown = data;
|
||||
EXPECT_EQ(y_unknown.TypeName(), "TensorList");
|
||||
EXPECT_EQ(y_unknown.TypeId(), MakeTypeIndex<VariantTensorDataProto>());
|
||||
EXPECT_EQ(y_unknown.TypeId(), TypeIndex::Make<VariantTensorDataProto>());
|
||||
EXPECT_EQ(y_unknown.DebugString(),
|
||||
strings::StrCat(
|
||||
"Variant<type: TensorList value: ", data.DebugString(), ">"));
|
||||
|
@ -323,6 +323,7 @@ cc_library(
|
||||
":cost_estimator",
|
||||
":op_context",
|
||||
":utils",
|
||||
"@com_google_absl//absl/strings",
|
||||
"//third_party/eigen3",
|
||||
"//tensorflow/core:framework",
|
||||
"//tensorflow/core:protos_all_cc",
|
||||
|
@ -16,6 +16,7 @@ limitations under the License.
|
||||
|
||||
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
|
||||
|
||||
#include "absl/strings/match.h"
|
||||
#include "third_party/eigen3/Eigen/Core"
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/attr_value_util.h"
|
||||
@ -23,6 +24,7 @@ limitations under the License.
|
||||
#include "tensorflow/core/framework/tensor_shape.pb.h"
|
||||
#include "tensorflow/core/framework/types.h"
|
||||
#include "tensorflow/core/grappler/clusters/utils.h"
|
||||
#include "tensorflow/core/grappler/costs/op_context.h"
|
||||
#include "tensorflow/core/grappler/costs/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
@ -101,16 +103,16 @@ static const Costs::Duration kMinComputeTime(1);
|
||||
|
||||
namespace {
|
||||
|
||||
string GetDataFormat(const OpInfo& op_info) {
|
||||
string data_format = "NHWC"; // Default format.
|
||||
std::string GetDataFormat(const OpInfo& op_info) {
|
||||
std::string data_format = "NHWC"; // Default format.
|
||||
if (op_info.attr().find("data_format") != op_info.attr().end()) {
|
||||
data_format = op_info.attr().at("data_format").s();
|
||||
}
|
||||
return data_format;
|
||||
}
|
||||
|
||||
string GetFilterFormat(const OpInfo& op_info) {
|
||||
string filter_format = "HWIO"; // Default format.
|
||||
std::string GetFilterFormat(const OpInfo& op_info) {
|
||||
std::string filter_format = "HWIO"; // Default format.
|
||||
if (op_info.attr().find("filter_format") != op_info.attr().end()) {
|
||||
filter_format = op_info.attr().at("filter_format").s();
|
||||
}
|
||||
@ -202,7 +204,7 @@ int64 CwiseOutputElementCount(const TensorShapeProto& input_shape_1,
|
||||
|
||||
// Helper function for determining whether there are repeated indices in the
|
||||
// input Einsum equation.
|
||||
bool CheckRepeatedDimensions(const string& dim_str) {
|
||||
bool CheckRepeatedDimensions(const absl::string_view dim_str) {
|
||||
int str_size = dim_str.size();
|
||||
for (int idx = 0; idx < str_size - 1; idx++) {
|
||||
if (dim_str.find(dim_str[idx], idx + 1) != std::string::npos) {
|
||||
@ -212,6 +214,75 @@ bool CheckRepeatedDimensions(const string& dim_str) {
|
||||
return false;
|
||||
}
|
||||
|
||||
// Auxiliary function for determining whether OpLevelCostEstimator is compatible
|
||||
// with a given Einsum.
|
||||
bool IsEinsumCorrectlyFormed(const OpContext& einsum_context) {
|
||||
const auto& op_info = einsum_context.op_info;
|
||||
|
||||
auto it = op_info.attr().find("equation");
|
||||
if (it == op_info.attr().end()) return false;
|
||||
const absl::string_view equation = it->second.s();
|
||||
std::vector<std::string> equation_split = absl::StrSplit(equation, "->");
|
||||
|
||||
if (equation_split.empty()) {
|
||||
LOG(WARNING) << "Einsum with malformed equation";
|
||||
return false;
|
||||
}
|
||||
std::vector<absl::string_view> input_split =
|
||||
absl::StrSplit(equation_split[0], ',');
|
||||
|
||||
// The current model covers Einsum operations with two operands and a RHS
|
||||
if (op_info.inputs_size() != 2 || equation_split.size() != 2) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||
return false;
|
||||
}
|
||||
const auto& a_input = op_info.inputs(0);
|
||||
const auto& b_input = op_info.inputs(1);
|
||||
absl::string_view rhs_str = equation_split[1];
|
||||
absl::string_view a_input_str = input_split[0];
|
||||
absl::string_view b_input_str = input_split[1];
|
||||
|
||||
// Ellipsis are not currently supported
|
||||
if (absl::StrContains(a_input_str, "...") ||
|
||||
absl::StrContains(b_input_str, "...")) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
|
||||
<< ", ellipsis not supported";
|
||||
return false;
|
||||
}
|
||||
|
||||
constexpr int kMatrixRank = 2;
|
||||
|
||||
bool a_input_shape_unknown = false;
|
||||
bool b_input_shape_unknown = false;
|
||||
|
||||
TensorShapeProto a_input_shape = MaybeGetMinimumShape(
|
||||
a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
|
||||
&a_input_shape_unknown);
|
||||
TensorShapeProto b_input_shape = MaybeGetMinimumShape(
|
||||
b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
|
||||
&b_input_shape_unknown);
|
||||
|
||||
if (a_input_str.size() != static_cast<size_t>(a_input_shape.dim_size()) ||
|
||||
b_input_str.size() != static_cast<size_t>(b_input_shape.dim_size())) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
|
||||
<< ", equation subscripts don't match tensor rank.";
|
||||
return false;
|
||||
}
|
||||
|
||||
// Subscripts where axis appears more than once for a single input are not yet
|
||||
// supported
|
||||
if (CheckRepeatedDimensions(a_input_str) ||
|
||||
CheckRepeatedDimensions(b_input_str) ||
|
||||
CheckRepeatedDimensions(rhs_str)) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
|
||||
<< ", Subscripts where axis appears more than once for a single "
|
||||
"input are not yet supported";
|
||||
return false;
|
||||
}
|
||||
|
||||
return true;
|
||||
}
|
||||
|
||||
} // namespace
|
||||
|
||||
// Return a minimum shape if the shape is unknown. If known, return the original
|
||||
@ -528,7 +599,7 @@ DeviceInfo OpLevelCostEstimator::GetDeviceInfo(
|
||||
}
|
||||
}
|
||||
} else if (device.type() == "GPU") {
|
||||
const string architecture = device.environment().at("architecture");
|
||||
const std::string architecture = device.environment().at("architecture");
|
||||
int cores_per_multiprocessor;
|
||||
if (architecture < "3") {
|
||||
// Fermi
|
||||
@ -695,7 +766,7 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
|
||||
VLOG(2) << "Original filter shape: " << original_filter_shape.DebugString();
|
||||
|
||||
int x_index, y_index, major_channel_index, minor_channel_index = -1;
|
||||
const string& data_format = GetDataFormat(op_info);
|
||||
const std::string& data_format = GetDataFormat(op_info);
|
||||
if (data_format == "NCHW") {
|
||||
major_channel_index = 1;
|
||||
y_index = 2;
|
||||
@ -712,7 +783,7 @@ OpLevelCostEstimator::ConvolutionDimensionsFromInputs(
|
||||
x_index = 2;
|
||||
major_channel_index = 3;
|
||||
}
|
||||
const string& filter_format = GetFilterFormat(op_info);
|
||||
const std::string& filter_format = GetFilterFormat(op_info);
|
||||
int filter_x_index, filter_y_index, in_major_channel_index, out_channel_index,
|
||||
in_minor_channel_index = -1;
|
||||
if (filter_format == "HWIO") {
|
||||
@ -906,6 +977,130 @@ int64 OpLevelCostEstimator::CountMatMulOperations(const OpInfo& op_info,
|
||||
return ops;
|
||||
}
|
||||
|
||||
bool OpLevelCostEstimator::GenerateBatchMatmulContextFromEinsum(
|
||||
const OpContext& einsum_context, OpContext* batch_matmul_context,
|
||||
bool* found_unknown_shapes) const {
|
||||
// This auxiliary function transforms an einsum OpContext into its equivalent
|
||||
// Batch Matmul OpContext. The function returns a boolean, which determines
|
||||
// whether it was successful in generating the output OpContext or not.
|
||||
|
||||
// Einsum computes a generalized contraction between tensors of arbitrary
|
||||
// dimension as defined by the equation written in the Einstein summation
|
||||
// convention. The number of tensors in the computation and the number of
|
||||
// contractions can be arbitrarily long. The current model only contemplates
|
||||
// Einsum equations, which can be translated into a single BatchMatMul
|
||||
// operation. Einsum operations with more than two operands are not currently
|
||||
// supported. Subscripts where an axis appears more than once for a single
|
||||
// input and ellipsis are currently also excluded. See:
|
||||
// https://www.tensorflow.org/api_docs/python/tf/einsum
|
||||
// We distinguish four kinds of dimensions, depending on their placement in
|
||||
// the equation:
|
||||
// + B: Batch dimensions: Dimensions which appear in both operands and RHS.
|
||||
// + K: Contracting dimensions: These appear in both inputs but not RHS.
|
||||
// + M: Operand A dimensions: These appear in the first operand and the RHS.
|
||||
// + N: Operand B dimensions: These appear in the second operand and the RHS.
|
||||
// Then, the operation to estimate is BatchMatMul([B,M,K],[B,K,N])
|
||||
|
||||
if (batch_matmul_context == nullptr) {
|
||||
VLOG(1) << "Output context should not be a nullptr.";
|
||||
return false;
|
||||
}
|
||||
if (!IsEinsumCorrectlyFormed(einsum_context)) return false;
|
||||
const auto& op_info = einsum_context.op_info;
|
||||
std::vector<std::string> equation_split =
|
||||
absl::StrSplit(op_info.attr().find("equation")->second.s(), "->");
|
||||
std::vector<absl::string_view> input_split =
|
||||
absl::StrSplit(equation_split[0], ',');
|
||||
const auto& a_input = op_info.inputs(0);
|
||||
const auto& b_input = op_info.inputs(1);
|
||||
absl::string_view rhs_str = equation_split[1];
|
||||
absl::string_view a_input_str = input_split[0];
|
||||
absl::string_view b_input_str = input_split[1];
|
||||
|
||||
constexpr int kMatrixRank = 2;
|
||||
|
||||
bool a_input_shape_unknown = false;
|
||||
bool b_input_shape_unknown = false;
|
||||
|
||||
TensorShapeProto a_input_shape = MaybeGetMinimumShape(
|
||||
a_input.shape(), std::max(kMatrixRank, a_input.shape().dim_size()),
|
||||
&a_input_shape_unknown);
|
||||
TensorShapeProto b_input_shape = MaybeGetMinimumShape(
|
||||
b_input.shape(), std::max(kMatrixRank, b_input.shape().dim_size()),
|
||||
&b_input_shape_unknown);
|
||||
|
||||
*found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
|
||||
(a_input.shape().dim_size() < kMatrixRank) ||
|
||||
(b_input.shape().dim_size() < kMatrixRank);
|
||||
|
||||
OpInfo batch_matmul_op_info = op_info;
|
||||
batch_matmul_op_info.mutable_inputs()->Clear();
|
||||
batch_matmul_op_info.set_op("BatchMatMul");
|
||||
|
||||
AttrValue transpose_attribute;
|
||||
transpose_attribute.set_b(false);
|
||||
(*batch_matmul_op_info.mutable_attr())["transpose_a"] = transpose_attribute;
|
||||
(*batch_matmul_op_info.mutable_attr())["transpose_b"] = transpose_attribute;
|
||||
|
||||
OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs();
|
||||
TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
|
||||
a_matrix->set_dtype(a_input.dtype());
|
||||
|
||||
OpInfo::TensorProperties* b_matrix = batch_matmul_op_info.add_inputs();
|
||||
b_matrix->set_dtype(b_input.dtype());
|
||||
TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
|
||||
|
||||
TensorShapeProto_Dim m_dim;
|
||||
TensorShapeProto_Dim n_dim;
|
||||
TensorShapeProto_Dim k_dim;
|
||||
|
||||
m_dim.set_size(1);
|
||||
n_dim.set_size(1);
|
||||
k_dim.set_size(1);
|
||||
|
||||
for (int i_idx = 0, a_input_str_size = a_input_str.size();
|
||||
i_idx < a_input_str_size; ++i_idx) {
|
||||
if (b_input_str.find(a_input_str[i_idx]) == std::string::npos) {
|
||||
if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||
return false;
|
||||
}
|
||||
|
||||
m_dim.set_size(m_dim.size() * a_input_shape.dim(i_idx).size());
|
||||
continue;
|
||||
} else if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
|
||||
// The dimension does not appear in the RHS, therefore it is a contracting
|
||||
// dimension.
|
||||
k_dim.set_size(k_dim.size() * a_input_shape.dim(i_idx).size());
|
||||
continue;
|
||||
}
|
||||
// It appears in both input operands, therefore we place it as an outer
|
||||
// dimension for the Batch Matmul.
|
||||
*(a_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
|
||||
*(b_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
|
||||
}
|
||||
for (int i_idx = 0, b_input_str_size = b_input_str.size();
|
||||
i_idx < b_input_str_size; ++i_idx) {
|
||||
if (a_input_str.find(b_input_str[i_idx]) == std::string::npos) {
|
||||
if (rhs_str.find(b_input_str[i_idx]) == std::string::npos) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||
return false;
|
||||
}
|
||||
n_dim.set_size(n_dim.size() * b_input_shape.dim(i_idx).size());
|
||||
}
|
||||
}
|
||||
|
||||
// The two inner-most dimensions of the Batch Matmul are added.
|
||||
*(a_matrix_shape->add_dim()) = m_dim;
|
||||
*(a_matrix_shape->add_dim()) = k_dim;
|
||||
*(b_matrix_shape->add_dim()) = k_dim;
|
||||
*(b_matrix_shape->add_dim()) = n_dim;
|
||||
|
||||
*batch_matmul_context = einsum_context;
|
||||
batch_matmul_context->op_info = batch_matmul_op_info;
|
||||
return true;
|
||||
}
|
||||
|
||||
int64 OpLevelCostEstimator::CountBatchMatMulOperations(
|
||||
const OpInfo& op_info, bool* found_unknown_shapes) {
|
||||
return CountBatchMatMulOperations(op_info, nullptr, found_unknown_shapes);
|
||||
@ -1327,7 +1522,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
|
||||
// contrib/fused_conv/kernels/fused_conv2d_bias_activation_op.cc
|
||||
|
||||
// TODO(yaozhang): Support NHWC_VECT_W.
|
||||
string data_format = GetDataFormat(op_context.op_info);
|
||||
std::string data_format = GetDataFormat(op_context.op_info);
|
||||
if (data_format != "NCHW" && data_format != "NHWC" &&
|
||||
data_format != "NCHW_VECT_C") {
|
||||
LOG(WARNING) << "unsupported data format: " << data_format;
|
||||
@ -1335,7 +1530,7 @@ Costs OpLevelCostEstimator::PredictFusedConv2DBiasActivation(
|
||||
cost.inaccurate = true;
|
||||
return cost;
|
||||
}
|
||||
string filter_format = GetFilterFormat(op_context.op_info);
|
||||
std::string filter_format = GetFilterFormat(op_context.op_info);
|
||||
if (filter_format != "HWIO" && filter_format != "OIHW" &&
|
||||
filter_format != "OIHW_VECT_I") {
|
||||
LOG(WARNING) << "unsupported filter format: " << filter_format;
|
||||
@ -1405,154 +1600,17 @@ Costs OpLevelCostEstimator::PredictMatMul(const OpContext& op_context) const {
|
||||
}
|
||||
|
||||
Costs OpLevelCostEstimator::PredictEinsum(const OpContext& op_context) const {
|
||||
// Einsum computes a generalized contraction between tensors of arbitrary
|
||||
// dimension as defined by the equation written in the Einstein summation
|
||||
// convention. The number of tensors in the computation and the number of
|
||||
// contractions can be arbitrarily long. The current model only contemplates
|
||||
// Einsum equations, which can be translated into a single BatchMatMul
|
||||
// operation. Einsum operations with more than two operands are not currently
|
||||
// supported. Subscripts where an axis appears more than once for a single
|
||||
// input and ellipsis are currently also excluded. See:
|
||||
// https://www.tensorflow.org/api_docs/python/tf/einsum
|
||||
// We distinguish four kinds of dimensions, depending on their placement in
|
||||
// the equation:
|
||||
// + B: Batch dimensions: Dimensions which appear in both operands and RHS.
|
||||
// + K: Contracting dimensions: These appear in both inputs but not RHS.
|
||||
// + M: Operand A dimensions: These appear in the first operand and the RHS.
|
||||
// + N: Operand B dimensions: These appear in the second operand and the RHS.
|
||||
// Then, the operation to estimate is BatchMatMul([B,M,K],[B,K,N])
|
||||
const auto& op_info = op_context.op_info;
|
||||
|
||||
auto it = op_info.attr().find("equation");
|
||||
if (it == op_info.attr().end()) return Costs::ZeroCosts(/*inaccurate=*/true);
|
||||
const string& equation = it->second.s();
|
||||
std::vector<string> equation_split = absl::StrSplit(equation, "->");
|
||||
|
||||
if (equation_split.empty()) {
|
||||
LOG(WARNING) << "Einsum with malformed equation";
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
std::vector<string> input_split = absl::StrSplit(equation_split[0], ',');
|
||||
|
||||
// The current model covers Einsum operations with two operands and a RHS
|
||||
if (op_info.inputs_size() != 2 || equation_split.size() != 2) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
string rhs_str = equation_split[1];
|
||||
string a_input_str = input_split[0];
|
||||
string b_input_str = input_split[1];
|
||||
|
||||
// Ellipsis are not currently supported
|
||||
if (a_input_str.find("...") != std::string::npos ||
|
||||
b_input_str.find("...") != std::string::npos) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
|
||||
<< ", ellipsis not supported";
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
|
||||
const auto& a_input = op_info.inputs(0);
|
||||
const auto& b_input = op_info.inputs(1);
|
||||
const int matrix_rank = 2;
|
||||
|
||||
OpContext batch_matmul_op_context;
|
||||
bool found_unknown_shapes = false;
|
||||
bool a_input_shape_unknown = false;
|
||||
bool b_input_shape_unknown = false;
|
||||
|
||||
TensorShapeProto a_input_shape = MaybeGetMinimumShape(
|
||||
a_input.shape(), std::max(matrix_rank, a_input.shape().dim_size()),
|
||||
&a_input_shape_unknown);
|
||||
TensorShapeProto b_input_shape = MaybeGetMinimumShape(
|
||||
b_input.shape(), std::max(matrix_rank, b_input.shape().dim_size()),
|
||||
&b_input_shape_unknown);
|
||||
|
||||
found_unknown_shapes = a_input_shape_unknown || b_input_shape_unknown ||
|
||||
(a_input.shape().dim_size() < matrix_rank) ||
|
||||
(b_input.shape().dim_size() < matrix_rank);
|
||||
|
||||
if (a_input_str.size() != static_cast<size_t>(a_input_shape.dim_size()) ||
|
||||
b_input_str.size() != static_cast<size_t>(b_input_shape.dim_size())) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
|
||||
<< ", equation subscripts don't match tensor rank.";
|
||||
bool success = GenerateBatchMatmulContextFromEinsum(
|
||||
op_context, &batch_matmul_op_context, &found_unknown_shapes);
|
||||
if (!success) {
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
|
||||
// Subscripts where axis appears more than once for a single input are not yet
|
||||
// supported
|
||||
if (CheckRepeatedDimensions(a_input_str) ||
|
||||
CheckRepeatedDimensions(b_input_str) ||
|
||||
CheckRepeatedDimensions(rhs_str)) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op()
|
||||
<< ", Subscripts where axis appears more than once for a single "
|
||||
"input are not yet supported";
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
|
||||
OpInfo batch_matmul_op_info = op_info;
|
||||
batch_matmul_op_info.mutable_inputs()->Clear();
|
||||
batch_matmul_op_info.set_op("BatchMatMul");
|
||||
|
||||
AttrValue transpose_attribute;
|
||||
transpose_attribute.set_b(false);
|
||||
(*batch_matmul_op_info.mutable_attr())["transpose_a"] = transpose_attribute;
|
||||
(*batch_matmul_op_info.mutable_attr())["transpose_b"] = transpose_attribute;
|
||||
|
||||
OpInfo::TensorProperties* a_matrix = batch_matmul_op_info.add_inputs();
|
||||
TensorShapeProto* a_matrix_shape = a_matrix->mutable_shape();
|
||||
a_matrix->set_dtype(a_input.dtype());
|
||||
|
||||
OpInfo::TensorProperties* b_matrix = batch_matmul_op_info.add_inputs();
|
||||
b_matrix->set_dtype(b_input.dtype());
|
||||
TensorShapeProto* b_matrix_shape = b_matrix->mutable_shape();
|
||||
|
||||
TensorShapeProto_Dim m_dim;
|
||||
TensorShapeProto_Dim n_dim;
|
||||
TensorShapeProto_Dim k_dim;
|
||||
|
||||
m_dim.set_size(1);
|
||||
n_dim.set_size(1);
|
||||
k_dim.set_size(1);
|
||||
|
||||
for (int i_idx = 0, a_input_str_size = a_input_str.size();
|
||||
i_idx < a_input_str_size; ++i_idx) {
|
||||
if (b_input_str.find(a_input_str[i_idx]) == std::string::npos) {
|
||||
if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
|
||||
m_dim.set_size(m_dim.size() * a_input_shape.dim(i_idx).size());
|
||||
continue;
|
||||
} else if (rhs_str.find(a_input_str[i_idx]) == std::string::npos) {
|
||||
// The dimension does not appear in the RHS, therefore it is a contracting
|
||||
// dimension.
|
||||
k_dim.set_size(k_dim.size() * a_input_shape.dim(i_idx).size());
|
||||
continue;
|
||||
}
|
||||
// It appears in both input operands, therefore we place it as an outer
|
||||
// dimension for the Batch Matmul.
|
||||
*(a_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
|
||||
*(b_matrix_shape->add_dim()) = a_input_shape.dim(i_idx);
|
||||
}
|
||||
for (int i_idx = 0, b_input_str_size = b_input_str.size();
|
||||
i_idx < b_input_str_size; ++i_idx) {
|
||||
if (a_input_str.find(b_input_str[i_idx]) == std::string::npos) {
|
||||
if (rhs_str.find(b_input_str[i_idx]) == std::string::npos) {
|
||||
VLOG(1) << "Missing accurate estimator for op: " << op_info.op();
|
||||
return PredictCostOfAnUnknownOp(op_context);
|
||||
}
|
||||
n_dim.set_size(n_dim.size() * b_input_shape.dim(i_idx).size());
|
||||
}
|
||||
}
|
||||
|
||||
// The two inner-most dimensions of the Batch Matmul are added.
|
||||
*(a_matrix_shape->add_dim()) = m_dim;
|
||||
*(a_matrix_shape->add_dim()) = k_dim;
|
||||
*(b_matrix_shape->add_dim()) = k_dim;
|
||||
*(b_matrix_shape->add_dim()) = n_dim;
|
||||
|
||||
OpContext batch_matmul_op_context = op_context;
|
||||
batch_matmul_op_context.op_info = batch_matmul_op_info;
|
||||
Costs costs = PredictCosts(batch_matmul_op_context);
|
||||
costs.inaccurate = costs.inaccurate || found_unknown_shapes;
|
||||
costs.num_ops_with_unknown_shapes = found_unknown_shapes;
|
||||
@ -1772,7 +1830,7 @@ Costs OpLevelCostEstimator::PredictFusedOp(
|
||||
|
||||
/* static */
|
||||
OpContext OpLevelCostEstimator::FusedChildContext(
|
||||
const OpContext& parent, const string& op_name,
|
||||
const OpContext& parent, const std::string& op_name,
|
||||
const OpInfo::TensorProperties& output,
|
||||
const std::vector<OpInfo::TensorProperties>& inputs) {
|
||||
// Setup the base parameters of our new context.
|
||||
@ -1821,7 +1879,7 @@ OpLevelCostEstimator::OpDimensionsFromInputs(
|
||||
VLOG(2) << "Image shape: " << image_shape.DebugString();
|
||||
|
||||
int x_index, y_index, channel_index;
|
||||
const string& data_format = GetDataFormat(op_info);
|
||||
const std::string& data_format = GetDataFormat(op_info);
|
||||
if (data_format == "NCHW") {
|
||||
channel_index = 1;
|
||||
y_index = 2;
|
||||
|
@ -138,6 +138,9 @@ class OpLevelCostEstimator {
|
||||
static int64 CountMatMulOperations(const OpInfo& op_info,
|
||||
MatMulDimensions* mat_mul,
|
||||
bool* found_unknown_shapes);
|
||||
bool GenerateBatchMatmulContextFromEinsum(const OpContext& einsum_context,
|
||||
OpContext* batch_matmul_context,
|
||||
bool* found_unknown_shapes) const;
|
||||
static int64 CountBatchMatMulOperations(const OpInfo& op_info,
|
||||
bool* found_unknown_shapes);
|
||||
static int64 CountBatchMatMulOperations(const OpInfo& op_info,
|
||||
|
@ -1762,6 +1762,7 @@ tf_cuda_cc_test(
|
||||
name = "conv_ops_test",
|
||||
size = "medium",
|
||||
srcs = ["conv_ops_test.cc"],
|
||||
tags = ["no_cuda11"], # b/159664089
|
||||
deps = [
|
||||
":conv_ops",
|
||||
":image",
|
||||
|
@ -90,7 +90,7 @@ class ResourceConditionalAccumulatorOp : public ConditionalAccumulatorBaseOp {
|
||||
h(1) = cinfo_.name();
|
||||
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
|
||||
ctx, 0, cinfo_.container(), cinfo_.name(),
|
||||
MakeTypeIndex<ConditionalAccumulatorBase>()));
|
||||
TypeIndex::Make<ConditionalAccumulatorBase>()));
|
||||
}
|
||||
|
||||
TF_DISALLOW_COPY_AND_ASSIGN(ResourceConditionalAccumulatorOp);
|
||||
|
@ -35,7 +35,7 @@ Status CreateHandle(OpKernelContext* ctx, T* resource,
|
||||
TF_RETURN_IF_ERROR(mgr->Create<T>(container_name, unique_name, resource));
|
||||
|
||||
*handle = MakeResourceHandle(container_name, unique_name, *ctx->device(),
|
||||
MakeTypeIndex<T>());
|
||||
TypeIndex::Make<T>());
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
|
@ -111,7 +111,7 @@ class ThreadPoolHandleOp : public OpKernel {
|
||||
}
|
||||
OP_REQUIRES_OK(ctx, MakeResourceHandleToOutput(
|
||||
ctx, 0, cinfo_.container(), cinfo_.name(),
|
||||
MakeTypeIndex<ThreadPoolResource>()));
|
||||
TypeIndex::Make<ThreadPoolResource>()));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -443,7 +443,7 @@ void IteratorHandleOp::Compute(OpKernelContext* context)
|
||||
}
|
||||
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
|
||||
context, 0, cinfo_.container(), cinfo_.name(),
|
||||
MakeTypeIndex<IteratorResource>()));
|
||||
TypeIndex::Make<IteratorResource>()));
|
||||
}
|
||||
|
||||
Status IteratorHandleOp::VerifyResource(IteratorResource* resource) {
|
||||
|
@ -475,7 +475,7 @@ class MultiDeviceIteratorHandleOp : public OpKernel {
|
||||
}
|
||||
OP_REQUIRES_OK(context, MakeResourceHandleToOutput(
|
||||
context, 0, container_name, unique_name,
|
||||
MakeTypeIndex<MultiDeviceIterator>()));
|
||||
TypeIndex::Make<MultiDeviceIterator>()));
|
||||
}
|
||||
|
||||
private:
|
||||
|
@ -126,7 +126,7 @@ class OpsTestBase : public ::testing::Test {
|
||||
std::string container_name =
|
||||
container.empty() ? rm->default_container() : container;
|
||||
EXPECT_TRUE(rm->Create(container_name, name, resource).ok());
|
||||
AddResourceInputInternal(container_name, name, MakeTypeIndex<T>());
|
||||
AddResourceInputInternal(container_name, name, TypeIndex::Make<T>());
|
||||
}
|
||||
|
||||
// Runs an operation producing 'num_outputs' outputs.
|
||||
|
@ -554,7 +554,7 @@ inline void TileGradientOp<Device, Tmultiples>::HandleCase(
|
||||
OpKernelContext* context, const std::vector<Tmultiples>& input_dims,
|
||||
const gtl::ArraySlice<Tmultiples>& multiples_array, Tensor* result) {
|
||||
LOG(FATAL) << "TileGradientOp: Invalid combination of Device, DT and NDIM: "
|
||||
<< MakeTypeIndex<Device>().name() << ", " << DataTypeString(DT)
|
||||
<< TypeIndex::Make<Device>().name() << ", " << DataTypeString(DT)
|
||||
<< ", " << NDIM;
|
||||
}
|
||||
|
||||
|
39
tensorflow/core/ops/compat/ops_history_v2/DecodeImage.pbtxt
Normal file
39
tensorflow/core/ops/compat/ops_history_v2/DecodeImage.pbtxt
Normal file
@ -0,0 +1,39 @@
|
||||
op {
|
||||
name: "DecodeImage"
|
||||
input_arg {
|
||||
name: "contents"
|
||||
type: DT_STRING
|
||||
}
|
||||
output_arg {
|
||||
name: "image"
|
||||
type_attr: "dtype"
|
||||
}
|
||||
attr {
|
||||
name: "channels"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_UINT8
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "expand_animations"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
@ -11476,6 +11476,45 @@ op {
|
||||
type: DT_UINT8
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "DecodeImage"
|
||||
input_arg {
|
||||
name: "contents"
|
||||
type: DT_STRING
|
||||
}
|
||||
output_arg {
|
||||
name: "image"
|
||||
type_attr: "dtype"
|
||||
}
|
||||
attr {
|
||||
name: "channels"
|
||||
type: "int"
|
||||
default_value {
|
||||
i: 0
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "dtype"
|
||||
type: "type"
|
||||
default_value {
|
||||
type: DT_UINT8
|
||||
}
|
||||
allowed_values {
|
||||
list {
|
||||
type: DT_UINT8
|
||||
type: DT_UINT16
|
||||
type: DT_FLOAT
|
||||
}
|
||||
}
|
||||
}
|
||||
attr {
|
||||
name: "expand_animations"
|
||||
type: "bool"
|
||||
default_value {
|
||||
b: true
|
||||
}
|
||||
}
|
||||
}
|
||||
op {
|
||||
name: "DecodeJSONExample"
|
||||
input_arg {
|
||||
|
@ -23,14 +23,14 @@ namespace tensorflow {
|
||||
struct MyRandomPODType {};
|
||||
|
||||
TEST(AbiTest, AbiDemangleTest) {
|
||||
EXPECT_EQ(port::MaybeAbiDemangle(MakeTypeIndex<int>().name()), "int");
|
||||
EXPECT_EQ(port::MaybeAbiDemangle(TypeIndex::Make<int>().name()), "int");
|
||||
|
||||
#ifdef PLATFORM_WINDOWS
|
||||
const char pod_type_name[] = "struct tensorflow::MyRandomPODType";
|
||||
#else
|
||||
const char pod_type_name[] = "tensorflow::MyRandomPODType";
|
||||
#endif
|
||||
EXPECT_EQ(port::MaybeAbiDemangle(MakeTypeIndex<MyRandomPODType>().name()),
|
||||
EXPECT_EQ(port::MaybeAbiDemangle(TypeIndex::Make<MyRandomPODType>().name()),
|
||||
pod_type_name);
|
||||
|
||||
EXPECT_EQ(
|
||||
|
@ -136,7 +136,7 @@ typedef struct TF_TString { // NOLINT
|
||||
// _Static_assert(CHAR_BIT == 8);
|
||||
// _Static_assert(sizeof(TF_TString) == 24);
|
||||
|
||||
extern inline TF_TString_Type TF_TString_GetType(const TF_TString *str) {
|
||||
static inline TF_TString_Type TF_TString_GetType(const TF_TString *str) {
|
||||
return (TF_TString_Type)(str->u.raw.raw[0] & TF_TSTR_TYPE_MASK); // NOLINT
|
||||
}
|
||||
|
||||
@ -168,12 +168,12 @@ static inline size_t TF_TString_ToInternalSizeT(size_t size,
|
||||
#endif // TF_TSTRING_LITTLE_ENDIAN
|
||||
}
|
||||
|
||||
extern inline void TF_TString_Init(TF_TString *str) {
|
||||
static inline void TF_TString_Init(TF_TString *str) {
|
||||
str->u.smll.size = 0;
|
||||
str->u.smll.str[0] = '\0';
|
||||
}
|
||||
|
||||
extern inline void TF_TString_Dealloc(TF_TString *str) {
|
||||
static inline void TF_TString_Dealloc(TF_TString *str) {
|
||||
if (TF_TString_GetType(str) == TF_TSTR_LARGE &&
|
||||
str->u.large.ptr != NULL) { // NOLINT
|
||||
free(str->u.large.ptr);
|
||||
@ -181,7 +181,7 @@ extern inline void TF_TString_Dealloc(TF_TString *str) {
|
||||
}
|
||||
}
|
||||
|
||||
extern inline size_t TF_TString_GetSize(const TF_TString *str) {
|
||||
static inline size_t TF_TString_GetSize(const TF_TString *str) {
|
||||
switch (TF_TString_GetType(str)) {
|
||||
case TF_TSTR_SMALL:
|
||||
return str->u.smll.size >> 2;
|
||||
@ -196,7 +196,7 @@ extern inline size_t TF_TString_GetSize(const TF_TString *str) {
|
||||
}
|
||||
}
|
||||
|
||||
extern inline size_t TF_TString_GetCapacity(const TF_TString *str) {
|
||||
static inline size_t TF_TString_GetCapacity(const TF_TString *str) {
|
||||
switch (TF_TString_GetType(str)) {
|
||||
case TF_TSTR_SMALL:
|
||||
return TF_TString_SmallCapacity;
|
||||
@ -209,7 +209,7 @@ extern inline size_t TF_TString_GetCapacity(const TF_TString *str) {
|
||||
}
|
||||
}
|
||||
|
||||
extern inline const char *TF_TString_GetDataPointer(const TF_TString *str) {
|
||||
static inline const char *TF_TString_GetDataPointer(const TF_TString *str) {
|
||||
switch (TF_TString_GetType(str)) {
|
||||
case TF_TSTR_SMALL:
|
||||
return str->u.smll.str;
|
||||
@ -225,7 +225,7 @@ extern inline const char *TF_TString_GetDataPointer(const TF_TString *str) {
|
||||
}
|
||||
}
|
||||
|
||||
extern inline char *TF_TString_ResizeUninitialized(TF_TString *str,
|
||||
static inline char *TF_TString_ResizeUninitialized(TF_TString *str,
|
||||
size_t new_size) {
|
||||
size_t curr_size = TF_TString_GetSize(str);
|
||||
size_t copy_size = TF_min(new_size, curr_size);
|
||||
@ -288,7 +288,7 @@ extern inline char *TF_TString_ResizeUninitialized(TF_TString *str,
|
||||
return str->u.large.ptr;
|
||||
}
|
||||
|
||||
extern inline char *TF_TString_GetMutableDataPointer(TF_TString *str) {
|
||||
static inline char *TF_TString_GetMutableDataPointer(TF_TString *str) {
|
||||
switch (TF_TString_GetType(str)) {
|
||||
case TF_TSTR_SMALL:
|
||||
return str->u.smll.str;
|
||||
@ -306,7 +306,7 @@ extern inline char *TF_TString_GetMutableDataPointer(TF_TString *str) {
|
||||
}
|
||||
}
|
||||
|
||||
extern inline void TF_TString_Reserve(TF_TString *str, size_t new_cap) {
|
||||
static inline void TF_TString_Reserve(TF_TString *str, size_t new_cap) {
|
||||
TF_TString_Type curr_type = TF_TString_GetType(str);
|
||||
|
||||
if (new_cap <= TF_TString_SmallCapacity) {
|
||||
@ -347,7 +347,7 @@ extern inline void TF_TString_Reserve(TF_TString *str, size_t new_cap) {
|
||||
str->u.large.cap = new_cap;
|
||||
}
|
||||
|
||||
extern inline char *TF_TString_Resize(TF_TString *str, size_t new_size,
|
||||
static inline char *TF_TString_Resize(TF_TString *str, size_t new_size,
|
||||
char c) {
|
||||
size_t curr_size = TF_TString_GetSize(str);
|
||||
char *cstr = TF_TString_ResizeUninitialized(str, new_size);
|
||||
@ -359,7 +359,7 @@ extern inline char *TF_TString_Resize(TF_TString *str, size_t new_size,
|
||||
return cstr;
|
||||
}
|
||||
|
||||
extern inline void TF_TString_AssignView(TF_TString *dst, const char *src,
|
||||
static inline void TF_TString_AssignView(TF_TString *dst, const char *src,
|
||||
size_t size) {
|
||||
TF_TString_Dealloc(dst);
|
||||
|
||||
@ -367,7 +367,7 @@ extern inline void TF_TString_AssignView(TF_TString *dst, const char *src,
|
||||
dst->u.view.ptr = src;
|
||||
}
|
||||
|
||||
extern inline void TF_TString_AppendN(TF_TString *dst, const char *src,
|
||||
static inline void TF_TString_AppendN(TF_TString *dst, const char *src,
|
||||
size_t src_size) {
|
||||
if (!src_size) return;
|
||||
|
||||
@ -378,21 +378,21 @@ extern inline void TF_TString_AppendN(TF_TString *dst, const char *src,
|
||||
memcpy(dst_c + dst_size, src, src_size);
|
||||
}
|
||||
|
||||
extern inline void TF_TString_Append(TF_TString *dst, const TF_TString *src) {
|
||||
static inline void TF_TString_Append(TF_TString *dst, const TF_TString *src) {
|
||||
const char *src_c = TF_TString_GetDataPointer(src);
|
||||
size_t size = TF_TString_GetSize(src);
|
||||
|
||||
TF_TString_AppendN(dst, src_c, size);
|
||||
}
|
||||
|
||||
extern inline void TF_TString_Copy(TF_TString *dst, const char *src,
|
||||
static inline void TF_TString_Copy(TF_TString *dst, const char *src,
|
||||
size_t size) {
|
||||
char *dst_c = TF_TString_ResizeUninitialized(dst, size);
|
||||
|
||||
if (size) memcpy(dst_c, src, size);
|
||||
}
|
||||
|
||||
extern inline void TF_TString_Assign(TF_TString *dst, const TF_TString *src) {
|
||||
static inline void TF_TString_Assign(TF_TString *dst, const TF_TString *src) {
|
||||
if (dst == src) return;
|
||||
|
||||
TF_TString_Dealloc(dst);
|
||||
@ -421,7 +421,7 @@ extern inline void TF_TString_Assign(TF_TString *dst, const TF_TString *src) {
|
||||
}
|
||||
}
|
||||
|
||||
extern inline void TF_TString_Move(TF_TString *dst, TF_TString *src) {
|
||||
static inline void TF_TString_Move(TF_TString *dst, TF_TString *src) {
|
||||
if (dst == src) return;
|
||||
|
||||
TF_TString_Dealloc(dst);
|
||||
|
@ -1518,6 +1518,7 @@ Status CuptiTracer::DisableActivityTracing() {
|
||||
|
||||
Status CuptiTracer::Finalize() {
|
||||
if (option_->cupti_finalize) {
|
||||
VLOG(1) << "CuptiFinalize";
|
||||
RETURN_IF_CUPTI_ERROR(cupti_interface_->Finalize());
|
||||
}
|
||||
return Status::OK();
|
||||
|
@ -612,8 +612,11 @@ Status GpuTracer::DoStart() {
|
||||
options_.activities_selected.push_back(CUPTI_ACTIVITY_KIND_MEMCPY2);
|
||||
options_.activities_selected.push_back(CUPTI_ACTIVITY_KIND_OVERHEAD);
|
||||
|
||||
// CUDA/CUPTI 10 have issues (leaks and crashes) with CuptiFinalize.
|
||||
#if CUDA_VERSION < 10000
|
||||
if (!trace_concurrent_kernels) options_.cupti_finalize = true;
|
||||
if (!options.trace_concurrent_kernels()) options_.cupti_finalize = true;
|
||||
#elif CUDA_VERSION >= 11000
|
||||
options_.cupti_finalize = true;
|
||||
#endif
|
||||
|
||||
CuptiTracerCollectorOptions collector_options;
|
||||
|
@ -141,7 +141,6 @@ cc_library(
|
||||
"//tensorflow/core/tpu/kernels:tpu_util_c_api_hdrs",
|
||||
"//tensorflow/stream_executor/tpu:tpu_executor_c_api_hdrs",
|
||||
"//tensorflow/stream_executor/tpu:tpu_node_context_c_api_hdrs",
|
||||
"//tensorflow/stream_executor/tpu:tpu_platform_hdrs",
|
||||
],
|
||||
)
|
||||
|
||||
|
@ -1,4 +1,5 @@
|
||||
# TPU Kernel Implementations
|
||||
|
||||
load(
|
||||
"//tensorflow/core/platform:build_config.bzl",
|
||||
"tf_proto_library_cc",
|
||||
@ -86,8 +87,8 @@ cc_library(
|
||||
hdrs = ["tpu_compile_c_api.h"],
|
||||
deps = [
|
||||
":tpu_mesh_state_c_api_hdrs",
|
||||
":tpu_ops_common_c_api_hdrs",
|
||||
":tpu_program_c_api_hdrs",
|
||||
":tpu_util_c_api_hdrs",
|
||||
"//tensorflow/core/tpu:libtftpu_header",
|
||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||
],
|
||||
@ -367,7 +368,6 @@ cc_library(
|
||||
|
||||
cc_library(
|
||||
name = "tpu_util_hdrs",
|
||||
srcs = [],
|
||||
hdrs = ["tpu_util.h"],
|
||||
deps = [
|
||||
":tpu_compilation_cache_key",
|
||||
@ -390,17 +390,11 @@ cc_library(
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_ops_common_c_api_hdrs",
|
||||
hdrs = ["tpu_ops_common_c_api.h"],
|
||||
alwayslink = True,
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "tpu_program_c_api_hdrs",
|
||||
hdrs = ["tpu_program_c_api.h"],
|
||||
deps = [
|
||||
":tpu_ops_common_c_api_hdrs",
|
||||
":tpu_util_c_api_hdrs",
|
||||
"//tensorflow/stream_executor/tpu:proto_helper",
|
||||
],
|
||||
alwayslink = True,
|
||||
|
@ -16,8 +16,8 @@ limitations under the License.
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_COMPILE_C_API_H_
|
||||
|
||||
#include "tensorflow/core/tpu/kernels/tpu_mesh_state_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_program_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
||||
#include "tensorflow/core/tpu/libtftpu.h"
|
||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||
|
||||
|
@ -1,20 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_OPS_COMMON_C_API_H_
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_OPS_COMMON_C_API_H_
|
||||
|
||||
typedef struct SE_Status SE_Status;
|
||||
|
||||
#endif // TENSORFLOW_CORE_TPU_KERNELS_TPU_OPS_COMMON_C_API_H_
|
@ -15,7 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_
|
||||
#define TENSORFLOW_CORE_TPU_KERNELS_TPU_PROGRAM_C_API_H_
|
||||
|
||||
#include "tensorflow/core/tpu/kernels/tpu_ops_common_c_api.h"
|
||||
#include "tensorflow/core/tpu/kernels/tpu_util_c_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/proto_helper.h"
|
||||
|
||||
typedef struct XLA_TpuProgram XLA_TpuProgram;
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
#include "tensorflow/core/platform/status.h"
|
||||
#include "tensorflow/core/tpu/tpu_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_node_context_c_api.h"
|
||||
#include "tensorflow/stream_executor/tpu/tpu_platform.h"
|
||||
|
||||
#define TFTPU_SET_FN(Struct, FnName) \
|
||||
Struct->FnName##Fn = \
|
||||
|
@ -137,6 +137,7 @@ tensorflow::Status SetTpuNodeContextStructFns(void* library_handle) {
|
||||
|
||||
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Create);
|
||||
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Free);
|
||||
TFTPU_SET_FN(node_context_fn, TpuNodeContext_Initialize);
|
||||
TFTPU_SET_FN(node_context_fn, TpuNodeContext_StopChipHeartbeats);
|
||||
TFTPU_SET_FN(node_context_fn, TpuNodeContext_CloseTpuHost);
|
||||
|
||||
|
@ -15370,6 +15370,80 @@ func MergeSummary(scope *Scope, inputs []tf.Output) (summary tf.Output) {
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// DecodeImageAttr is an optional argument to DecodeImage.
|
||||
type DecodeImageAttr func(optionalAttr)
|
||||
|
||||
// DecodeImageChannels sets the optional channels attribute to value.
|
||||
//
|
||||
// value: Number of color channels for the decoded image.
|
||||
// If not specified, defaults to 0
|
||||
func DecodeImageChannels(value int64) DecodeImageAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["channels"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeImageDtype sets the optional dtype attribute to value.
|
||||
//
|
||||
// value: The desired DType of the returned Tensor.
|
||||
// If not specified, defaults to DT_UINT8
|
||||
func DecodeImageDtype(value tf.DataType) DecodeImageAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["dtype"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// DecodeImageExpandAnimations sets the optional expand_animations attribute to value.
|
||||
//
|
||||
// value: Controls the output shape of the returned op. If True, the returned op will
|
||||
// produce a 3-D tensor for PNG, JPEG, and BMP files; and a 4-D tensor for all
|
||||
// GIFs, whether animated or not. If, False, the returned op will produce a 3-D
|
||||
// tensor for all file types and will truncate animated GIFs to the first frame.
|
||||
// If not specified, defaults to true
|
||||
func DecodeImageExpandAnimations(value bool) DecodeImageAttr {
|
||||
return func(m optionalAttr) {
|
||||
m["expand_animations"] = value
|
||||
}
|
||||
}
|
||||
|
||||
// Function for decode_bmp, decode_gif, decode_jpeg, and decode_png.
|
||||
//
|
||||
// Detects whether an image is a BMP, GIF, JPEG, or PNG, and performs the
|
||||
// appropriate operation to convert the input bytes string into a Tensor of type
|
||||
// dtype.
|
||||
//
|
||||
// *NOTE*: decode_gif returns a 4-D array [num_frames, height, width, 3], as
|
||||
// opposed to decode_bmp, decode_jpeg and decode_png, which return 3-D arrays
|
||||
// [height, width, num_channels]. Make sure to take this into account when
|
||||
// constructing your graph if you are intermixing GIF files with BMP, JPEG, and/or
|
||||
// PNG files. Alternately, set the expand_animations argument of this function to
|
||||
// False, in which case the op will return 3-dimensional tensors and will truncate
|
||||
// animated GIF files to the first frame.
|
||||
//
|
||||
// Arguments:
|
||||
// contents: 0-D. The encoded image bytes.
|
||||
//
|
||||
// Returns 3-D with shape `[height, width, channels]` or 4-D with shape
|
||||
// `[frame, height, width, channels]`..
|
||||
func DecodeImage(scope *Scope, contents tf.Output, optional ...DecodeImageAttr) (image tf.Output) {
|
||||
if scope.Err() != nil {
|
||||
return
|
||||
}
|
||||
attrs := map[string]interface{}{}
|
||||
for _, a := range optional {
|
||||
a(attrs)
|
||||
}
|
||||
opspec := tf.OpSpec{
|
||||
Type: "DecodeImage",
|
||||
Input: []tf.Input{
|
||||
contents,
|
||||
},
|
||||
Attrs: attrs,
|
||||
}
|
||||
op := scope.AddOperation(opspec)
|
||||
return op.Output(0)
|
||||
}
|
||||
|
||||
// AvgPoolAttr is an optional argument to AvgPool.
|
||||
type AvgPoolAttr func(optionalAttr)
|
||||
|
||||
|
@ -177,6 +177,92 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
|
||||
}
|
||||
}
|
||||
|
||||
// We have this parse function instead of directly returning kTfLiteOk from the
|
||||
// switch-case in ParseOpData because this function is used as part of the
|
||||
// selective registration for the OpResolver implementation in micro.
|
||||
TfLiteStatus ParseAbs(const Operator*, BuiltinOperator, ErrorReporter*,
|
||||
BuiltinDataAllocator*, void**) {
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseAdd(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
|
||||
|
||||
SafeBuiltinDataAllocator safe_allocator(allocator);
|
||||
std::unique_ptr<TfLiteAddParams, SafeBuiltinDataAllocator::BuiltinDataDeleter>
|
||||
params = safe_allocator.Allocate<TfLiteAddParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
|
||||
const AddOptions* schema_params = op->builtin_options_as_AddOptions();
|
||||
|
||||
if (schema_params != nullptr) {
|
||||
params->activation =
|
||||
ConvertActivation(schema_params->fused_activation_function());
|
||||
params->pot_scale_int16 = schema_params->pot_scale_int16();
|
||||
} else {
|
||||
// TODO(b/157480169): We should either return kTfLiteError or fill in some
|
||||
// reasonable defaults in the params struct. We are not doing so until we
|
||||
// better undertand the ramifications of changing the legacy behavior.
|
||||
}
|
||||
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseArgMax(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
|
||||
|
||||
SafeBuiltinDataAllocator safe_allocator(allocator);
|
||||
std::unique_ptr<TfLiteArgMaxParams,
|
||||
SafeBuiltinDataAllocator::BuiltinDataDeleter>
|
||||
params = safe_allocator.Allocate<TfLiteArgMaxParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
|
||||
const ArgMaxOptions* schema_params = op->builtin_options_as_ArgMaxOptions();
|
||||
|
||||
if (schema_params != nullptr) {
|
||||
TF_LITE_ENSURE_STATUS(ConvertTensorType(
|
||||
schema_params->output_type(), ¶ms->output_type, error_reporter));
|
||||
} else {
|
||||
// TODO(b/157480169): We should either return kTfLiteError or fill in some
|
||||
// reasonable defaults in the params struct. We are not doing so until we
|
||||
// better undertand the ramifications of changing the legacy behavior.
|
||||
}
|
||||
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
CheckParsePointerParams(op, error_reporter, allocator, builtin_data);
|
||||
|
||||
SafeBuiltinDataAllocator safe_allocator(allocator);
|
||||
std::unique_ptr<TfLiteArgMinParams,
|
||||
SafeBuiltinDataAllocator::BuiltinDataDeleter>
|
||||
params = safe_allocator.Allocate<TfLiteArgMinParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
|
||||
const ArgMinOptions* schema_params = op->builtin_options_as_ArgMinOptions();
|
||||
|
||||
if (schema_params != nullptr) {
|
||||
TF_LITE_ENSURE_STATUS(ConvertTensorType(
|
||||
schema_params->output_type(), ¶ms->output_type, error_reporter));
|
||||
} else {
|
||||
// TODO(b/157480169): We should either return kTfLiteError or fill in some
|
||||
// reasonable defaults in the params struct. We are not doing so until we
|
||||
// better undertand the ramifications of changing the legacy behavior.
|
||||
}
|
||||
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
|
||||
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data) {
|
||||
@ -430,6 +516,22 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
SafeBuiltinDataAllocator safe_allocator(allocator);
|
||||
*builtin_data = nullptr;
|
||||
switch (op_type) {
|
||||
case BuiltinOperator_ABS: {
|
||||
return ParseAbs(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_ADD: {
|
||||
return ParseAdd(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_ARG_MAX: {
|
||||
return ParseArgMax(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_ARG_MIN: {
|
||||
return ParseArgMin(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
|
||||
case BuiltinOperator_CONV_2D: {
|
||||
return ParseConv2D(op, op_type, error_reporter, allocator, builtin_data);
|
||||
}
|
||||
@ -586,17 +688,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_ADD: {
|
||||
auto params = safe_allocator.Allocate<TfLiteAddParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
if (const auto* schema_params = op->builtin_options_as_AddOptions()) {
|
||||
params->activation =
|
||||
ConvertActivation(schema_params->fused_activation_function());
|
||||
params->pot_scale_int16 = schema_params->pot_scale_int16();
|
||||
}
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_DIV: {
|
||||
auto params = safe_allocator.Allocate<TfLiteDivParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
@ -840,28 +931,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_ARG_MAX: {
|
||||
auto params = safe_allocator.Allocate<TfLiteArgMaxParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
if (const auto* schema_params = op->builtin_options_as_ArgMaxOptions()) {
|
||||
TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->output_type(),
|
||||
¶ms->output_type,
|
||||
error_reporter));
|
||||
}
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_ARG_MIN: {
|
||||
auto params = safe_allocator.Allocate<TfLiteArgMinParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
if (const auto* schema_params = op->builtin_options_as_ArgMinOptions()) {
|
||||
TF_LITE_ENSURE_STATUS(ConvertTensorType(schema_params->output_type(),
|
||||
¶ms->output_type,
|
||||
error_reporter));
|
||||
}
|
||||
*builtin_data = params.release();
|
||||
return kTfLiteOk;
|
||||
}
|
||||
case BuiltinOperator_TRANSPOSE_CONV: {
|
||||
auto params = safe_allocator.Allocate<TfLiteTransposeConvParams>();
|
||||
TF_LITE_ENSURE(error_reporter, params != nullptr);
|
||||
@ -1021,7 +1090,6 @@ TfLiteStatus ParseOpData(const Operator* op, BuiltinOperator op_type,
|
||||
return kTfLiteOk;
|
||||
}
|
||||
// Below are the ops with no builtin_data structure.
|
||||
case BuiltinOperator_ABS:
|
||||
case BuiltinOperator_BATCH_TO_SPACE_ND:
|
||||
// TODO(aselle): Implement call in BuiltinOptions, but nullptrs are
|
||||
// ok for now, since there is no call implementation either.
|
||||
|
@ -75,6 +75,22 @@ TfLiteStatus ConvertTensorType(TensorType tensor_type, TfLiteType* type,
|
||||
// removed once we are no longer using ParseOpData for the OpResolver
|
||||
// implementation in micro.
|
||||
|
||||
TfLiteStatus ParseAbs(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseAdd(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseArgMax(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseArgMin(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
||||
TfLiteStatus ParseConv2D(const Operator* op, BuiltinOperator op_type,
|
||||
ErrorReporter* error_reporter,
|
||||
BuiltinDataAllocator* allocator, void** builtin_data);
|
||||
|
@ -27,21 +27,19 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GenerateConvolutionTransposedCode(
|
||||
const OperationDef& op_def, const LinearStorage& biases, int src_depth,
|
||||
int dst_depth, const CLDevice& device,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations) {
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data",
|
||||
WHSBPoint{"src_size.x", "src_size.y", "src_size.z", "src_size.w"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data",
|
||||
WHSBPoint{"dst_size.x", "dst_size.y", "dst_size.z", "dst_size.w"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GenerateConvolutionTransposedCode(const OperationDef& op_def,
|
||||
int src_depth, int dst_depth,
|
||||
const CLDevice& device,
|
||||
Arguments* args) {
|
||||
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
|
||||
src_desc->SetTextureAddressMode(GetFastestZeroMode(device));
|
||||
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
|
||||
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
|
||||
|
||||
const std::string batch_id = op_def.IsBatchSupported() ? "B" : "";
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
|
||||
switch (op_def.precision) {
|
||||
@ -61,23 +59,19 @@ std::string GenerateConvolutionTransposedCode(
|
||||
}
|
||||
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
||||
c += " __constant FLT4* filters, \n";
|
||||
c += biases.GetDeclaration();
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 src_size, \n";
|
||||
c += " int4 dst_size \n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
if (op_def.IsBatchSupported()) {
|
||||
c += " int linear_id = get_global_id(0);\n";
|
||||
c += " int X = linear_id / dst_size.w;\n";
|
||||
c += " int B = linear_id % dst_size.w;\n";
|
||||
c += " int X = linear_id / args.dst_tensor.Batch();\n";
|
||||
c += " int B = linear_id % args.dst_tensor.Batch();\n";
|
||||
c += " args.dst_tensor.SetBatchRef(B);\n";
|
||||
c += " args.src_tensor.SetBatchRef(B);\n";
|
||||
} else {
|
||||
c += " int X = get_global_id(0);\n";
|
||||
}
|
||||
c += " int Y = get_global_id(1);\n";
|
||||
c += " if (X >= src_size.x || Y >= src_size.y) return;\n";
|
||||
c += " if (X >= args.src_tensor.Width() || Y >= args.src_tensor.Height()) "
|
||||
"return;\n";
|
||||
for (int d = 0; d < dst_depth; ++d) {
|
||||
const std::string layer = std::to_string(d);
|
||||
c += " ACCUM_FLT4 r" + layer + "[2][2];\n";
|
||||
@ -91,61 +85,48 @@ std::string GenerateConvolutionTransposedCode(
|
||||
const std::string z = std::to_string(s);
|
||||
c += " {\n";
|
||||
if (src_tensor_type == TensorStorageType::BUFFER) {
|
||||
c += " bool x_in = X + 1 < src_size.x;\n";
|
||||
c += " bool y_in = Y + 1 < src_size.y;\n";
|
||||
c +=
|
||||
" FLT4 src0 = " + src_tensor.ReadWHSB("X", "Y", z, batch_id) + ";\n";
|
||||
c += " bool x_in = X + 1 < args.src_tensor.Width();\n";
|
||||
c += " bool y_in = Y + 1 < args.src_tensor.Height();\n";
|
||||
c += " FLT4 src0 = args.src_tensor.Read(X, Y, " + z + ");\n";
|
||||
c += " FLT4 src1 = (FLT4)(0.0);\n";
|
||||
c += " FLT4 src2 = (FLT4)(0.0);\n";
|
||||
c += " FLT4 src3 = (FLT4)(0.0);\n";
|
||||
c += " if (x_in) {\n";
|
||||
c += " src1 = " + src_tensor.ReadWHSB("X + 1", "Y", z, batch_id) +
|
||||
";\n";
|
||||
c += " src1 = args.src_tensor.Read(X + 1, Y, " + z + ");\n";
|
||||
c += " }\n";
|
||||
c += " if (y_in) {\n";
|
||||
c += " src2 = " + src_tensor.ReadWHSB("X", "Y + 1", z, batch_id) +
|
||||
";\n";
|
||||
c += " src2 = args.src_tensor.Read(X, Y + 1, " + z + ");\n";
|
||||
c += " }\n";
|
||||
c += " if (x_in && y_in) {\n";
|
||||
c += " src3 = " + src_tensor.ReadWHSB("X + 1", "Y + 1", z, batch_id) +
|
||||
";\n";
|
||||
c += " src3 = args.src_tensor.Read(X + 1, Y + 1, " + z + ");\n";
|
||||
c += " }\n";
|
||||
} else if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) {
|
||||
c +=
|
||||
" " + src_tensor.GetAddressWHSB("c0", "X", "Y", z, batch_id) + ";\n";
|
||||
c += " " + src_tensor.GetAddressWHSB("c1", "X + 1", "Y", z, batch_id) +
|
||||
";\n";
|
||||
c += " " + src_tensor.GetAddressWHSB("c2", "X", "Y + 1", z, batch_id) +
|
||||
";\n";
|
||||
c += " " +
|
||||
src_tensor.GetAddressWHSB("c3", "X + 1", "Y + 1", z, batch_id) +
|
||||
";\n";
|
||||
c += " bool x_in = X + 1 < src_size.x;\n";
|
||||
c += " bool y_in = Y + 1 < src_size.y;\n";
|
||||
c += " args.src_tensor.GetAddress(c0, X, Y, " + z + ");\n";
|
||||
c += " args.src_tensor.GetAddress(c1, X + 1, Y, " + z + ");\n";
|
||||
c += " args.src_tensor.GetAddress(c2, X, Y + 1, " + z + ");\n";
|
||||
c += " args.src_tensor.GetAddress(c3, X + 1, Y + 1, " + z + ");\n";
|
||||
c += " bool x_in = X + 1 < args.src_tensor.Width();\n";
|
||||
c += " bool y_in = Y + 1 < args.src_tensor.Height();\n";
|
||||
c += " c1 = select(-1, c1, x_in);\n";
|
||||
c += " c2 = select(-1, c2, y_in);\n";
|
||||
c += " c3 = select(-1, c3, x_in && y_in);\n";
|
||||
c += " FLT4 src0 = " + src_tensor.Read("c0") + ";\n";
|
||||
c += " FLT4 src1 = " + src_tensor.Read("c1") + ";\n";
|
||||
c += " FLT4 src2 = " + src_tensor.Read("c2") + ";\n";
|
||||
c += " FLT4 src3 = " + src_tensor.Read("c3") + ";\n";
|
||||
c += " FLT4 src0 = args.src_tensor.Read(c0);\n";
|
||||
c += " FLT4 src1 = args.src_tensor.Read(c1);\n";
|
||||
c += " FLT4 src2 = args.src_tensor.Read(c2);\n";
|
||||
c += " FLT4 src3 = args.src_tensor.Read(c3);\n";
|
||||
} else {
|
||||
const auto mode = GetFastestZeroMode(device);
|
||||
c += " FLT4 src0 = " + src_tensor.ReadWHSB("X", "Y", z, batch_id, mode) +
|
||||
";\n";
|
||||
c += " FLT4 src1 = " +
|
||||
src_tensor.ReadWHSB("X + 1", "Y", z, batch_id, mode) + ";\n";
|
||||
c += " FLT4 src2 = " +
|
||||
src_tensor.ReadWHSB("X", "Y + 1", z, batch_id, mode) + ";\n";
|
||||
c += " FLT4 src3 = " +
|
||||
src_tensor.ReadWHSB("X + 1", "Y + 1", z, batch_id, mode) + ";\n";
|
||||
c += " FLT4 src0 = args.src_tensor.Read(X, Y, " + z + ");\n";
|
||||
c += " FLT4 src1 = args.src_tensor.Read(X + 1, Y, " + z + ");\n";
|
||||
c += " FLT4 src2 = args.src_tensor.Read(X, Y + 1, " + z + ");\n";
|
||||
c += " FLT4 src3 = args.src_tensor.Read(X + 1, Y + 1, " + z + ");\n";
|
||||
}
|
||||
for (int d = 0; d < dst_depth; ++d) {
|
||||
const std::string layer = std::to_string(d);
|
||||
const std::string f_offset = std::to_string(filters_index);
|
||||
filters_index++;
|
||||
c += " {\n";
|
||||
c += " __constant FLT4* L0 = filters + 36 * " + f_offset + ";\n";
|
||||
c += " __constant FLT4* L0 = args.weights.GetPtr() + 36 * " + f_offset +
|
||||
";\n";
|
||||
c += " CONV(r" + layer + "[0][0], src0, L0, 0);\n";
|
||||
c += " CONV(r" + layer + "[0][1], src0, L0, 4);\n";
|
||||
c += " CONV(r" + layer + "[0][1], src1, L0, 8);\n";
|
||||
@ -164,7 +145,8 @@ std::string GenerateConvolutionTransposedCode(
|
||||
for (int d = 0; d < dst_depth; ++d) {
|
||||
const std::string layer = std::to_string(d);
|
||||
c += " {\n";
|
||||
c += " FLT4 bias_val = " + biases.ReadLinearFLT4(layer) + ";\n";
|
||||
c += " FLT4 bias_val = args.weights.Read(" +
|
||||
std::to_string(36 * filters_index + d) + ");\n";
|
||||
for (int y = 0; y < 2; ++y) {
|
||||
for (int x = 0; x < 2; ++x) {
|
||||
const std::string x_coord = "X + " + std::to_string(x);
|
||||
@ -172,14 +154,8 @@ std::string GenerateConvolutionTransposedCode(
|
||||
c += " {\n";
|
||||
c += " FLT4 result = TO_FLT4(r" + layer + "[" + std::to_string(y) +
|
||||
"][" + std::to_string(x) + "]) + bias_val;\n";
|
||||
const std::string x_3dcoord = op_def.IsBatchSupported()
|
||||
? "(" + x_coord + ") * dst_size.w + B"
|
||||
: x_coord;
|
||||
const LinkingContext context{"result", x_3dcoord, y_coord, layer};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " +
|
||||
dst_tensor.WriteWHSB("result", x_coord, y_coord, layer, batch_id) +
|
||||
"\n";
|
||||
c += " args.dst_tensor.Write(result, " + x_coord + ", " + y_coord +
|
||||
", " + layer + ");\n";
|
||||
c += " }\n";
|
||||
}
|
||||
}
|
||||
@ -200,8 +176,6 @@ ConvolutionTransposed3x3Thin::ConvolutionTransposed3x3Thin(
|
||||
ConvolutionTransposed3x3Thin::ConvolutionTransposed3x3Thin(
|
||||
ConvolutionTransposed3x3Thin&& operation)
|
||||
: GPUOperation(std::move(operation)),
|
||||
weights_(std::move(operation.weights_)),
|
||||
biases_(std::move(operation.biases_)),
|
||||
src_channels_(operation.src_channels_),
|
||||
dst_channels_(operation.dst_channels_),
|
||||
kernel_(std::move(operation.kernel_)),
|
||||
@ -210,8 +184,6 @@ ConvolutionTransposed3x3Thin::ConvolutionTransposed3x3Thin(
|
||||
ConvolutionTransposed3x3Thin& ConvolutionTransposed3x3Thin::operator=(
|
||||
ConvolutionTransposed3x3Thin&& operation) {
|
||||
if (this != &operation) {
|
||||
weights_ = std::move(operation.weights_);
|
||||
biases_ = std::move(operation.biases_);
|
||||
std::swap(src_channels_, operation.src_channels_);
|
||||
std::swap(dst_channels_, operation.dst_channels_);
|
||||
kernel_ = std::move(operation.kernel_);
|
||||
@ -223,25 +195,25 @@ ConvolutionTransposed3x3Thin& ConvolutionTransposed3x3Thin::operator=(
|
||||
|
||||
absl::Status ConvolutionTransposed3x3Thin::Compile(
|
||||
const CreationContext& creation_context) {
|
||||
const auto code = GenerateConvolutionTransposedCode(
|
||||
definition_, biases_, DivideRoundUp(src_channels_, 4),
|
||||
DivideRoundUp(dst_channels_, 4), *creation_context.device,
|
||||
linked_operations_);
|
||||
std::string code = GenerateConvolutionTransposedCode(
|
||||
definition_, DivideRoundUp(src_channels_, 4),
|
||||
DivideRoundUp(dst_channels_, 4), *creation_context.device, &args_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
|
||||
{{"dst_tensor", element_wise_code}},
|
||||
&code));
|
||||
return creation_context.cache->GetOrCreateCLKernel(
|
||||
code, "main_function", *creation_context.context,
|
||||
*creation_context.device, &kernel_);
|
||||
}
|
||||
|
||||
absl::Status ConvolutionTransposed3x3Thin::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(biases_.GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(src_[0]->GetWHSB()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
||||
return absl::OkStatus();
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
|
||||
return args_.Bind(kernel_.kernel());
|
||||
}
|
||||
|
||||
int3 ConvolutionTransposed3x3Thin::GetGridSize() const {
|
||||
@ -282,15 +254,7 @@ absl::Status CreateConvolutionTransposed3x3Thin(
|
||||
}
|
||||
*result = ConvolutionTransposed3x3Thin(definition, attr);
|
||||
RETURN_IF_ERROR(
|
||||
result->UploadWeights(attr.weights, creation_context.context));
|
||||
LinearStorageCreateInfo create_info;
|
||||
create_info.storage_type =
|
||||
DeduceLinearStorageType(definition.GetPrimaryStorageType());
|
||||
create_info.data_type = definition.GetDataType();
|
||||
create_info.name = "biases";
|
||||
create_info.aligned_size = attr.weights.shape.o;
|
||||
RETURN_IF_ERROR(CreateLinearStorage(
|
||||
create_info, attr.bias, creation_context.context, &result->biases_));
|
||||
result->UploadData(attr.weights, attr.bias, creation_context.context));
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
|
@ -59,8 +59,9 @@ class ConvolutionTransposed3x3Thin : public GPUOperation {
|
||||
const OperationDef& definition,
|
||||
const ConvolutionTransposedAttributes& attr);
|
||||
template <DataType T>
|
||||
absl::Status UploadWeights(const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
CLContext* context);
|
||||
absl::Status UploadData(const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
const tflite::gpu::Tensor<Linear, T>& biases,
|
||||
CLContext* context);
|
||||
|
||||
template <DataType S, typename T>
|
||||
void RearrangeWeightsData(const tflite::gpu::Tensor<OHWI, S>& weights,
|
||||
@ -69,9 +70,6 @@ class ConvolutionTransposed3x3Thin : public GPUOperation {
|
||||
absl::Status BindArguments();
|
||||
int3 GetGridSize() const;
|
||||
|
||||
Buffer weights_;
|
||||
LinearStorage biases_;
|
||||
|
||||
int src_channels_;
|
||||
int dst_channels_;
|
||||
|
||||
@ -80,29 +78,58 @@ class ConvolutionTransposed3x3Thin : public GPUOperation {
|
||||
};
|
||||
|
||||
template <DataType T>
|
||||
absl::Status ConvolutionTransposed3x3Thin::UploadWeights(
|
||||
const tflite::gpu::Tensor<OHWI, T>& weights, CLContext* context) {
|
||||
absl::Status ConvolutionTransposed3x3Thin::UploadData(
|
||||
const tflite::gpu::Tensor<OHWI, T>& weights,
|
||||
const tflite::gpu::Tensor<Linear, T>& biases, CLContext* context) {
|
||||
const int src_depth = DivideRoundUp(src_channels_, 4);
|
||||
const int dst_depth = DivideRoundUp(dst_channels_, 4);
|
||||
const int kernel_x = 3; // This operation support only 3x3 kernel
|
||||
const int kernel_y = 3;
|
||||
const int flt4_count = kernel_x * kernel_y * src_depth * dst_depth * 4;
|
||||
|
||||
const int flt4_size = definition_.precision == CalculationsPrecision::F32
|
||||
? sizeof(float4)
|
||||
: sizeof(half4);
|
||||
const bool f32_weights = definition_.precision == CalculationsPrecision::F32;
|
||||
|
||||
if (definition_.GetDataType() == DataType::FLOAT32) {
|
||||
BufferDescriptor desc;
|
||||
desc.element_type = f32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
|
||||
desc.element_size = 4;
|
||||
desc.memory_type = MemoryType::CONSTANT;
|
||||
|
||||
Buffer weights_buffer;
|
||||
if (f32_weights) {
|
||||
std::vector<float4> gpu_data(flt4_count);
|
||||
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
|
||||
return CreateReadOnlyBuffer(flt4_size * flt4_count, gpu_data.data(),
|
||||
context, &weights_);
|
||||
for (int i = 0; i < dst_depth; ++i) {
|
||||
float4 bias_value(0.0f);
|
||||
for (int c = 0; c < 4; ++c) {
|
||||
int ch = i * 4 + c;
|
||||
bias_value[c] = ch < weights.shape.o ? biases.data[ch] : 0.0f;
|
||||
}
|
||||
gpu_data.push_back(bias_value);
|
||||
}
|
||||
RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(float4) * gpu_data.size(),
|
||||
gpu_data.data(), context,
|
||||
&weights_buffer));
|
||||
} else {
|
||||
std::vector<half4> gpu_data(flt4_count);
|
||||
RearrangeWeightsData(weights, absl::MakeSpan(gpu_data));
|
||||
return CreateReadOnlyBuffer(flt4_size * flt4_count, gpu_data.data(),
|
||||
context, &weights_);
|
||||
for (int i = 0; i < dst_depth; ++i) {
|
||||
half4 bias_value(0.0f);
|
||||
for (int c = 0; c < 4; ++c) {
|
||||
int ch = i * 4 + c;
|
||||
bias_value[c] = ch < weights.shape.o ? biases.data[ch] : 0.0f;
|
||||
}
|
||||
gpu_data.push_back(bias_value);
|
||||
}
|
||||
RETURN_IF_ERROR(CreateReadOnlyBuffer(sizeof(half4) * gpu_data.size(),
|
||||
gpu_data.data(), context,
|
||||
&weights_buffer));
|
||||
}
|
||||
|
||||
args_.AddObject("weights", AccessType::READ,
|
||||
absl::make_unique<Buffer>(std::move(weights_buffer)),
|
||||
absl::make_unique<BufferDescriptor>(desc));
|
||||
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
template <DataType S, typename T>
|
||||
|
@ -43,7 +43,7 @@ TEST_F(OpenCLOperationTest, ConvolutionTransposed3x3ThinSimpleWeights) {
|
||||
attr.weights.shape = OHWI(1, 3, 3, 1);
|
||||
attr.weights.data = {1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f, 1.0f};
|
||||
attr.bias.shape = Linear(2);
|
||||
attr.bias.data = {0.0f};
|
||||
attr.bias.data = {0.0f, 0.0f};
|
||||
|
||||
for (auto storage : env_.GetSupportedStorages()) {
|
||||
for (auto precision : env_.GetSupportedPrecisions()) {
|
||||
|
@ -28,55 +28,47 @@ namespace gpu {
|
||||
namespace cl {
|
||||
namespace {
|
||||
|
||||
std::string GenerateDepthwiseConvCode(
|
||||
const OperationDef& op_def,
|
||||
const std::vector<ElementwiseOperation*>& linked_operations,
|
||||
const CLDevice& device, bool weights_are_buffer, bool local_mem_uploads) {
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
TensorCodeGenerator src_tensor(
|
||||
"src_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.src_tensors[0]);
|
||||
TensorCodeGenerator dst_tensor(
|
||||
"dst_data", WHSPoint{"dst_size.x", "dst_size.y", "dst_size.z"},
|
||||
op_def.dst_tensors[0]);
|
||||
std::string GenerateDepthwiseConvCode(const OperationDef& op_def,
|
||||
const CLDevice& device,
|
||||
bool weights_are_buffer,
|
||||
bool local_mem_uploads, Arguments* args) {
|
||||
auto src_desc = absl::make_unique<TensorDescriptor>(op_def.src_tensors[0]);
|
||||
src_desc->SetTextureAddressMode(GetFastestZeroMode(device));
|
||||
args->AddObjectRef("src_tensor", AccessType::READ, std::move(src_desc));
|
||||
args->AddObjectRef(
|
||||
"dst_tensor", AccessType::WRITE,
|
||||
absl::make_unique<TensorDescriptor>(op_def.dst_tensors[0]));
|
||||
const auto src_tensor_type = op_def.src_tensors[0].storage_type;
|
||||
|
||||
const auto mode = GetFastestZeroMode(device);
|
||||
|
||||
const bool manual_clamp = src_tensor_type == TensorStorageType::BUFFER ||
|
||||
src_tensor_type == TensorStorageType::IMAGE_BUFFER;
|
||||
|
||||
std::string c = GetCommonDefines(op_def.precision);
|
||||
if (local_mem_uploads) {
|
||||
c += "__attribute__((reqd_work_group_size(8, 4, 1)))\n";
|
||||
}
|
||||
c += "__kernel void main_function(\n";
|
||||
c += src_tensor.GetDeclaration(AccessType::READ) + ",\n";
|
||||
if (weights_are_buffer) {
|
||||
c += " __global FLT4* filters\n";
|
||||
} else {
|
||||
c += " __read_only image2d_t filters\n";
|
||||
}
|
||||
c += GetArgsDeclaration(linked_operations);
|
||||
c += dst_tensor.GetDeclaration(AccessType::WRITE) + ",\n";
|
||||
c += " int4 dst_size\n";
|
||||
c += ") {\n";
|
||||
c += "$0) {\n";
|
||||
c += " int X = get_global_id(0) * 2;\n";
|
||||
c += " int Y = get_global_id(1) * 2;\n";
|
||||
c += " int Z = get_global_id(2);\n";
|
||||
c += " int S = get_global_id(2);\n";
|
||||
c += " ACCUM_FLT4 r0 = (ACCUM_FLT4)(0.0f);\n";
|
||||
c += " ACCUM_FLT4 r1 = (ACCUM_FLT4)(0.0f);\n";
|
||||
c += " ACCUM_FLT4 r2 = (ACCUM_FLT4)(0.0f);\n";
|
||||
c += " ACCUM_FLT4 r3 = (ACCUM_FLT4)(0.0f);\n";
|
||||
if (!local_mem_uploads) {
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) "
|
||||
"return;\n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
|
||||
"|| S >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
}
|
||||
if (local_mem_uploads) {
|
||||
c += " __local FLT4 f[10];\n";
|
||||
c += " event_t e = async_work_group_copy(f, filters + Z * 10, 10, 0);\n";
|
||||
c += " event_t e = async_work_group_copy(f, args.weights.GetPtr() + S * "
|
||||
"10, 10, 0);\n";
|
||||
c += " wait_group_events(1, &e);\n";
|
||||
} else if (weights_are_buffer) {
|
||||
c += " __global FLT4* f = filters + Z * 10;\n";
|
||||
c += " __global FLT4* f = args.weights.GetPtr() + S * 10;\n";
|
||||
}
|
||||
c += " FLT4 s0;\n";
|
||||
c += " FLT4 s1;\n";
|
||||
@ -87,15 +79,15 @@ std::string GenerateDepthwiseConvCode(
|
||||
std::string xc[4] = {"X - 1", "X", "X + 1", "X + 2"};
|
||||
std::string yc[4] = {"Y - 1", "Y", "Y + 1", "Y + 2"};
|
||||
if (!weights_are_buffer) {
|
||||
c += " FLT4 f0 = READ_IMAGE(filters, smp_none, (int2)(0, Z));\n";
|
||||
c += " FLT4 f1 = READ_IMAGE(filters, smp_none, (int2)(1, Z));\n";
|
||||
c += " FLT4 f2 = READ_IMAGE(filters, smp_none, (int2)(2, Z));\n";
|
||||
c += " FLT4 f3 = READ_IMAGE(filters, smp_none, (int2)(3, Z));\n";
|
||||
c += " FLT4 f4 = READ_IMAGE(filters, smp_none, (int2)(4, Z));\n";
|
||||
c += " FLT4 f5 = READ_IMAGE(filters, smp_none, (int2)(5, Z));\n";
|
||||
c += " FLT4 f6 = READ_IMAGE(filters, smp_none, (int2)(6, Z));\n";
|
||||
c += " FLT4 f7 = READ_IMAGE(filters, smp_none, (int2)(7, Z));\n";
|
||||
c += " FLT4 f8 = READ_IMAGE(filters, smp_none, (int2)(8, Z));\n";
|
||||
c += " FLT4 f0 = args.weights.Read(0, S);\n";
|
||||
c += " FLT4 f1 = args.weights.Read(1, S);\n";
|
||||
c += " FLT4 f2 = args.weights.Read(2, S);\n";
|
||||
c += " FLT4 f3 = args.weights.Read(3, S);\n";
|
||||
c += " FLT4 f4 = args.weights.Read(4, S);\n";
|
||||
c += " FLT4 f5 = args.weights.Read(5, S);\n";
|
||||
c += " FLT4 f6 = args.weights.Read(6, S);\n";
|
||||
c += " FLT4 f7 = args.weights.Read(7, S);\n";
|
||||
c += " FLT4 f8 = args.weights.Read(8, S);\n";
|
||||
}
|
||||
if (manual_clamp) {
|
||||
c += " int x0 = X - 1;\n";
|
||||
@ -106,25 +98,25 @@ std::string GenerateDepthwiseConvCode(
|
||||
c += " int y1 = Y;\n";
|
||||
c += " int y2 = Y + 1;\n";
|
||||
c += " int y3 = Y + 2;\n";
|
||||
c += " bool x0_in = x0 >= 0 && x0 < dst_size.x;\n";
|
||||
c += " bool x1_in = x1 >= 0 && x1 < dst_size.x;\n";
|
||||
c += " bool x2_in = x2 >= 0 && x2 < dst_size.x;\n";
|
||||
c += " bool x3_in = x3 >= 0 && x3 < dst_size.x;\n";
|
||||
c += " bool y0_in = y0 >= 0 && y0 < dst_size.y;\n";
|
||||
c += " bool y1_in = y1 >= 0 && y1 < dst_size.y;\n";
|
||||
c += " bool y2_in = y2 >= 0 && y2 < dst_size.y;\n";
|
||||
c += " bool y3_in = y3 >= 0 && y3 < dst_size.y;\n";
|
||||
c += " x0 = clamp(x0, 0, dst_size.x - 1);\n";
|
||||
c += " x1 = clamp(x1, 0, dst_size.x - 1);\n";
|
||||
c += " x2 = clamp(x2, 0, dst_size.x - 1);\n";
|
||||
c += " x3 = clamp(x3, 0, dst_size.x - 1);\n";
|
||||
c += " y0 = clamp(y0, 0, dst_size.y - 1);\n";
|
||||
c += " y1 = clamp(y1, 0, dst_size.y - 1);\n";
|
||||
c += " y2 = clamp(y2, 0, dst_size.y - 1);\n";
|
||||
c += " y3 = clamp(y3, 0, dst_size.y - 1);\n";
|
||||
c += " bool x0_in = x0 >= 0 && x0 < args.dst_tensor.Width();\n";
|
||||
c += " bool x1_in = x1 >= 0 && x1 < args.dst_tensor.Width();\n";
|
||||
c += " bool x2_in = x2 >= 0 && x2 < args.dst_tensor.Width();\n";
|
||||
c += " bool x3_in = x3 >= 0 && x3 < args.dst_tensor.Width();\n";
|
||||
c += " bool y0_in = y0 >= 0 && y0 < args.dst_tensor.Height();\n";
|
||||
c += " bool y1_in = y1 >= 0 && y1 < args.dst_tensor.Height();\n";
|
||||
c += " bool y2_in = y2 >= 0 && y2 < args.dst_tensor.Height();\n";
|
||||
c += " bool y3_in = y3 >= 0 && y3 < args.dst_tensor.Height();\n";
|
||||
c += " x0 = clamp(x0, 0, args.dst_tensor.Width() - 1);\n";
|
||||
c += " x1 = clamp(x1, 0, args.dst_tensor.Width() - 1);\n";
|
||||
c += " x2 = clamp(x2, 0, args.dst_tensor.Width() - 1);\n";
|
||||
c += " x3 = clamp(x3, 0, args.dst_tensor.Width() - 1);\n";
|
||||
c += " y0 = clamp(y0, 0, args.dst_tensor.Height() - 1);\n";
|
||||
c += " y1 = clamp(y1, 0, args.dst_tensor.Height() - 1);\n";
|
||||
c += " y2 = clamp(y2, 0, args.dst_tensor.Height() - 1);\n";
|
||||
c += " y3 = clamp(y3, 0, args.dst_tensor.Height() - 1);\n";
|
||||
if (src_tensor_type == TensorStorageType::BUFFER) {
|
||||
c += " __global FLT4* src_loc = src_data + Z * dst_size.x * "
|
||||
"dst_size.y;\n";
|
||||
c += " __global FLT4* src_loc = "
|
||||
"args.src_tensor.GetPtrWithSliceOffset(S);\n";
|
||||
}
|
||||
xc[0] = "x0";
|
||||
xc[1] = "x1";
|
||||
@ -150,29 +142,29 @@ std::string GenerateDepthwiseConvCode(
|
||||
auto read_4x_line = [&](int y) {
|
||||
if (src_tensor_type == TensorStorageType::BUFFER) {
|
||||
const std::string y_in = "y" + std::to_string(y) + "_in";
|
||||
c += " s0 = src_loc[" + yc[y] + " * dst_size.x + " + xc[0] +
|
||||
"] * (FLT)(x0_in && " + y_in + ");\n";
|
||||
c += " s1 = src_loc[" + yc[y] + " * dst_size.x + " + xc[1] +
|
||||
"] * (FLT)(x1_in && " + y_in + ");\n";
|
||||
c += " s2 = src_loc[" + yc[y] + " * dst_size.x + " + xc[2] +
|
||||
"] * (FLT)(x2_in && " + y_in + ");\n";
|
||||
c += " s3 = src_loc[" + yc[y] + " * dst_size.x + " + xc[3] +
|
||||
"] * (FLT)(x3_in && " + y_in + ");\n";
|
||||
c += " s0 = src_loc[args.src_tensor.GetWHOffset(" + xc[0] + ", " +
|
||||
yc[y] + ")] * (FLT)(x0_in && " + y_in + ");\n";
|
||||
c += " s1 = src_loc[args.src_tensor.GetWHOffset(" + xc[1] + ", " +
|
||||
yc[y] + ")] * (FLT)(x1_in && " + y_in + ");\n";
|
||||
c += " s2 = src_loc[args.src_tensor.GetWHOffset(" + xc[2] + ", " +
|
||||
yc[y] + ")] * (FLT)(x2_in && " + y_in + ");\n";
|
||||
c += " s3 = src_loc[args.src_tensor.GetWHOffset(" + xc[3] + ", " +
|
||||
yc[y] + ")] * (FLT)(x3_in && " + y_in + ");\n";
|
||||
} else if (src_tensor_type == TensorStorageType::IMAGE_BUFFER) {
|
||||
const std::string y_in = "y" + std::to_string(y) + "_in";
|
||||
c += " s0 = " + src_tensor.ReadWHS(xc[0], yc[y], "Z", mode) +
|
||||
" * (FLT)(x0_in && " + y_in + ");\n";
|
||||
c += " s1 = " + src_tensor.ReadWHS(xc[1], yc[y], "Z", mode) +
|
||||
" * (FLT)(x1_in && " + y_in + ");\n";
|
||||
c += " s2 = " + src_tensor.ReadWHS(xc[2], yc[y], "Z", mode) +
|
||||
" * (FLT)(x2_in && " + y_in + ");\n";
|
||||
c += " s3 = " + src_tensor.ReadWHS(xc[3], yc[y], "Z", mode) +
|
||||
" * (FLT)(x3_in && " + y_in + ");\n";
|
||||
c += " s0 = args.src_tensor.Read(" + xc[0] + ", " + yc[y] +
|
||||
", S) * (FLT)(x0_in && " + y_in + ");\n";
|
||||
c += " s1 = args.src_tensor.Read(" + xc[1] + ", " + yc[y] +
|
||||
", S) * (FLT)(x1_in && " + y_in + ");\n";
|
||||
c += " s2 = args.src_tensor.Read(" + xc[2] + ", " + yc[y] +
|
||||
", S) * (FLT)(x2_in && " + y_in + ");\n";
|
||||
c += " s3 = args.src_tensor.Read(" + xc[3] + ", " + yc[y] +
|
||||
", S) * (FLT)(x3_in && " + y_in + ");\n";
|
||||
} else {
|
||||
c += " s0 = " + src_tensor.ReadWHS(xc[0], yc[y], "Z", mode) + ";\n";
|
||||
c += " s1 = " + src_tensor.ReadWHS(xc[1], yc[y], "Z", mode) + ";\n";
|
||||
c += " s2 = " + src_tensor.ReadWHS(xc[2], yc[y], "Z", mode) + ";\n";
|
||||
c += " s3 = " + src_tensor.ReadWHS(xc[3], yc[y], "Z", mode) + ";\n";
|
||||
c += " s0 = args.src_tensor.Read(" + xc[0] + ", " + yc[y] + ", S);\n";
|
||||
c += " s1 = args.src_tensor.Read(" + xc[1] + ", " + yc[y] + ", S);\n";
|
||||
c += " s2 = args.src_tensor.Read(" + xc[2] + ", " + yc[y] + ", S);\n";
|
||||
c += " s3 = args.src_tensor.Read(" + xc[3] + ", " + yc[y] + ", S);\n";
|
||||
}
|
||||
};
|
||||
c += " {\n";
|
||||
@ -224,40 +216,38 @@ std::string GenerateDepthwiseConvCode(
|
||||
c += " r3 += TO_ACCUM_TYPE(" + W[8] + " * s3);\n";
|
||||
c += " }\n";
|
||||
if (!weights_are_buffer) {
|
||||
c += " FLT4 bias = READ_IMAGE(filters, smp_none, (int2)(9, Z));\n";
|
||||
c += " FLT4 bias = args.weights.Read(9, S);\n";
|
||||
}
|
||||
c += " r0 += TO_ACCUM_TYPE(" + bias + ");\n";
|
||||
c += " r1 += TO_ACCUM_TYPE(" + bias + ");\n";
|
||||
c += " r2 += TO_ACCUM_TYPE(" + bias + ");\n";
|
||||
c += " r3 += TO_ACCUM_TYPE(" + bias + ");\n";
|
||||
if (local_mem_uploads) {
|
||||
c += " if (X >= dst_size.x || Y >= dst_size.y || Z >= dst_size.z) "
|
||||
"return;\n";
|
||||
c += " if (X >= args.dst_tensor.Width() || Y >= args.dst_tensor.Height() "
|
||||
"|| "
|
||||
"S >= args.dst_tensor.Slices()) { \n";
|
||||
c += " return; \n";
|
||||
c += " } \n";
|
||||
}
|
||||
c += " if(X + 0 < dst_size.x && Y + 0 < dst_size.y) {\n";
|
||||
c += " if(X + 0 < args.dst_tensor.Width() && Y + 0 < "
|
||||
"args.dst_tensor.Height()) {\n";
|
||||
c += " FLT4 result = TO_FLT4(r0);\n";
|
||||
c += " " + dst_tensor.GetAddressWHS("address", "X + 0", "Y + 0", "Z") + "\n";
|
||||
LinkingContext context{"result", "X + 0", "Y + 0", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHS("result", "X + 0", "Y + 0", "Z") + "\n";
|
||||
c += " args.dst_tensor.Write(result, X + 0, Y + 0, S)\n";
|
||||
c += " }\n";
|
||||
c += " if(X + 1 < dst_size.x && Y + 0 < dst_size.y) {\n";
|
||||
c += " if(X + 1 < args.dst_tensor.Width() && Y + 0 < "
|
||||
"args.dst_tensor.Height()) {\n";
|
||||
c += " FLT4 result = TO_FLT4(r1);\n";
|
||||
context = {"result", "X + 1", "Y + 0", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHS("result", "X + 1", "Y + 0", "Z") + "\n";
|
||||
c += " args.dst_tensor.Write(result, X + 1, Y + 0, S)\n";
|
||||
c += " }\n";
|
||||
c += " if(X + 0 < dst_size.x && Y + 1 < dst_size.y) {\n";
|
||||
c += " if(X + 0 < args.dst_tensor.Width() && Y + 1 < "
|
||||
"args.dst_tensor.Height()) {\n";
|
||||
c += " FLT4 result = TO_FLT4(r2);\n";
|
||||
context = {"result", "X + 0", "Y + 1", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHS("result", "X + 0", "Y + 1", "Z") + "\n";
|
||||
c += " args.dst_tensor.Write(result, X + 0, Y + 1, S)\n";
|
||||
c += " }\n";
|
||||
c += " if(X + 1 < dst_size.x && Y + 1 < dst_size.y) {\n";
|
||||
c += " if(X + 1 < args.dst_tensor.Width() && Y + 1 < "
|
||||
"args.dst_tensor.Height()) {\n";
|
||||
c += " FLT4 result = TO_FLT4(r3);\n";
|
||||
context = {"result", "X + 1", "Y + 1", "Z"};
|
||||
c += PostProcess(linked_operations, context);
|
||||
c += " " + dst_tensor.WriteWHS("result", "X + 1", "Y + 1", "Z") + "\n";
|
||||
c += " args.dst_tensor.Write(result, X + 1, Y + 1, S)\n";
|
||||
c += " }\n";
|
||||
c += "}\n";
|
||||
|
||||
@ -277,9 +267,6 @@ DepthwiseConv3x3::DepthwiseConv3x3(DepthwiseConv3x3&& operation)
|
||||
: GPUOperation(std::move(operation)),
|
||||
weights_are_buffer_(operation.weights_are_buffer_),
|
||||
local_mem_uploads_(operation.local_mem_uploads_),
|
||||
weights_tex2d_(std::move(operation.weights_tex2d_)),
|
||||
weights_buf_(std::move(operation.weights_buf_)),
|
||||
weights_(operation.weights_),
|
||||
kernel_(std::move(operation.kernel_)),
|
||||
work_group_size_(operation.work_group_size_) {}
|
||||
|
||||
@ -287,9 +274,6 @@ DepthwiseConv3x3& DepthwiseConv3x3::operator=(DepthwiseConv3x3&& operation) {
|
||||
if (this != &operation) {
|
||||
std::swap(weights_are_buffer_, operation.weights_are_buffer_);
|
||||
std::swap(local_mem_uploads_, operation.local_mem_uploads_);
|
||||
weights_tex2d_ = std::move(operation.weights_tex2d_);
|
||||
weights_buf_ = std::move(operation.weights_buf_);
|
||||
std::swap(weights_, operation.weights_);
|
||||
kernel_ = std::move(operation.kernel_);
|
||||
std::swap(work_group_size_, operation.work_group_size_);
|
||||
GPUOperation::operator=(std::move(operation));
|
||||
@ -300,8 +284,15 @@ DepthwiseConv3x3& DepthwiseConv3x3::operator=(DepthwiseConv3x3&& operation) {
|
||||
absl::Status DepthwiseConv3x3::Compile(
|
||||
const CreationContext& creation_context) {
|
||||
std::string code = GenerateDepthwiseConvCode(
|
||||
definition_, linked_operations_, *creation_context.device,
|
||||
weights_are_buffer_, local_mem_uploads_);
|
||||
definition_, *creation_context.device, weights_are_buffer_,
|
||||
local_mem_uploads_, &args_);
|
||||
std::string element_wise_code;
|
||||
RETURN_IF_ERROR(
|
||||
MergeOperations(linked_operations_, &args_, &element_wise_code));
|
||||
RETURN_IF_ERROR(args_.TransformToCLCode(creation_context.device->GetInfo(),
|
||||
{{"dst_tensor", element_wise_code}},
|
||||
&code));
|
||||
|
||||
std::vector<CompilerOptions> options;
|
||||
if (definition_.precision == CalculationsPrecision::F16 &&
|
||||
creation_context.device->IsPowerVR()) {
|
||||
@ -313,13 +304,10 @@ absl::Status DepthwiseConv3x3::Compile(
|
||||
}
|
||||
|
||||
absl::Status DepthwiseConv3x3::BindArguments() {
|
||||
kernel_.ResetBindingCounter();
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(src_[0]->GetMemoryPtr()));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(weights_));
|
||||
RETURN_IF_ERROR(BindArgs(&kernel_, linked_operations_));
|
||||
RETURN_IF_ERROR(kernel_.SetMemoryAuto(dst_[0]->GetMemoryPtrForWriting()));
|
||||
RETURN_IF_ERROR(kernel_.SetBytesAuto(dst_[0]->GetWHSB()));
|
||||
return absl::OkStatus();
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("src_tensor", src_[0]));
|
||||
RETURN_IF_ERROR(args_.SetObjectRef("dst_tensor", dst_[0]));
|
||||
RETURN_IF_ERROR(SetArguments(linked_operations_, &args_));
|
||||
return args_.Bind(kernel_.kernel());
|
||||
}
|
||||
|
||||
int3 DepthwiseConv3x3::GetGridSize() const {
|
||||
|
@ -71,9 +71,6 @@ class DepthwiseConv3x3 : public GPUOperation {
|
||||
|
||||
bool weights_are_buffer_;
|
||||
bool local_mem_uploads_;
|
||||
Texture2D weights_tex2d_;
|
||||
Buffer weights_buf_;
|
||||
cl_mem weights_;
|
||||
|
||||
CLKernel kernel_;
|
||||
int3 work_group_size_ = int3(8, 4, 1);
|
||||
@ -90,17 +87,19 @@ absl::Status DepthwiseConv3x3::UploadWeightsAndBiases(
|
||||
const bool fp32_weights = definition_.precision == CalculationsPrecision::F32;
|
||||
const int float4_size = fp32_weights ? 16 : 8;
|
||||
|
||||
Texture2D weights_tex2d;
|
||||
Buffer weights_buf;
|
||||
if (fp32_weights) {
|
||||
std::vector<float4> gpu_data(elements_count);
|
||||
RearrangeWeightsAndBiasesData(weights, biases, absl::MakeSpan(gpu_data));
|
||||
if (weights_are_buffer_) {
|
||||
RETURN_IF_ERROR(CreateReadOnlyBuffer(float4_size * elements_count,
|
||||
gpu_data.data(), context,
|
||||
&weights_buf_));
|
||||
&weights_buf));
|
||||
} else {
|
||||
RETURN_IF_ERROR(CreateTexture2DRGBA(
|
||||
definition_.GetDataType(), texture_width, texture_height,
|
||||
gpu_data.data(), context, &weights_tex2d_));
|
||||
gpu_data.data(), context, &weights_tex2d));
|
||||
}
|
||||
} else {
|
||||
std::vector<half4> gpu_data(elements_count);
|
||||
@ -108,18 +107,27 @@ absl::Status DepthwiseConv3x3::UploadWeightsAndBiases(
|
||||
if (weights_are_buffer_) {
|
||||
RETURN_IF_ERROR(CreateReadOnlyBuffer(float4_size * elements_count,
|
||||
gpu_data.data(), context,
|
||||
&weights_buf_));
|
||||
&weights_buf));
|
||||
} else {
|
||||
RETURN_IF_ERROR(CreateTexture2DRGBA(
|
||||
definition_.GetDataType(), texture_width, texture_height,
|
||||
gpu_data.data(), context, &weights_tex2d_));
|
||||
gpu_data.data(), context, &weights_tex2d));
|
||||
}
|
||||
}
|
||||
|
||||
if (weights_are_buffer_) {
|
||||
weights_ = weights_buf_.GetMemoryPtr();
|
||||
BufferDescriptor desc;
|
||||
desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
|
||||
desc.element_size = 4;
|
||||
args_.AddObject("weights", AccessType::READ,
|
||||
absl::make_unique<Buffer>(std::move(weights_buf)),
|
||||
absl::make_unique<BufferDescriptor>(desc));
|
||||
} else {
|
||||
weights_ = weights_tex2d_.GetMemoryPtr();
|
||||
Texture2DDescriptor desc;
|
||||
desc.element_type = fp32_weights ? DataType::FLOAT32 : DataType::FLOAT16;
|
||||
args_.AddObject("weights", AccessType::READ,
|
||||
absl::make_unique<Texture2D>(std::move(weights_tex2d)),
|
||||
absl::make_unique<Texture2DDescriptor>(desc));
|
||||
}
|
||||
|
||||
return absl::OkStatus();
|
||||
|
@ -172,6 +172,10 @@ absl::Status TensorDescriptor::PerformSelector(
|
||||
return PerformWriteLinearSelector(args, result);
|
||||
} else if (selector == "GetAddress") {
|
||||
return PerformGetAddressSelector(args, result);
|
||||
} else if (selector == "GetPtrWithSliceOffset") {
|
||||
return PerformGetPtrWithSliceOffsetSelector(args, result);
|
||||
} else if (selector == "GetWHOffset") {
|
||||
return PerformGetWHOffsetSelector(args, result);
|
||||
} else {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"TensorDescriptor don't have selector with name - ", selector));
|
||||
@ -351,6 +355,43 @@ absl::Status TensorDescriptor::PerformGetAddressSelector(
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformGetPtrWithSliceOffsetSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
if (storage_type != TensorStorageType::BUFFER) {
|
||||
return absl::InvalidArgumentError(
|
||||
"GetPtrWithSliceOffset selector can be used only with BUFFER");
|
||||
}
|
||||
if (args.size() != 1) {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"GetPtrWithSliceOffset require one argument(slice coordinate), but ",
|
||||
args.size(), " was passed"));
|
||||
}
|
||||
const std::string width = IsBatchedWidth() ? "width_batched" : "width";
|
||||
if (HasAxis(Axis::DEPTH)) {
|
||||
*result =
|
||||
absl::StrCat("buffer + ", args[0], " * ", width, " * height * depth");
|
||||
} else {
|
||||
*result = absl::StrCat("buffer + ", args[0], " * ", width, " * height");
|
||||
}
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
absl::Status TensorDescriptor::PerformGetWHOffsetSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const {
|
||||
if (storage_type != TensorStorageType::BUFFER) {
|
||||
return absl::InvalidArgumentError(
|
||||
"GetWHOffset selector can be used only with BUFFER");
|
||||
}
|
||||
if (args.size() != 2) {
|
||||
return absl::NotFoundError(absl::StrCat(
|
||||
"GetWHOffset require two arguments(X and Y coordinates), but ",
|
||||
args.size(), " was passed"));
|
||||
}
|
||||
const std::string width = IsBatchedWidth() ? "width_batched" : "width";
|
||||
*result = absl::StrCat(args[1], " * ", width, " + ", args[0]);
|
||||
return absl::OkStatus();
|
||||
}
|
||||
|
||||
std::string TensorDescriptor::DeclareAddress(const std::string& var_name,
|
||||
const std::string& address) const {
|
||||
return absl::StrCat(StorageTypeToAddressType(), " ", var_name, " = ", address,
|
||||
|
@ -85,6 +85,12 @@ struct TensorDescriptor : public GPUObjectDescriptor {
|
||||
absl::Status PerformGetAddressSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
absl::Status PerformGetPtrWithSliceOffsetSelector(
|
||||
const std::vector<std::string>& args, std::string* result) const;
|
||||
|
||||
absl::Status PerformGetWHOffsetSelector(const std::vector<std::string>& args,
|
||||
std::string* result) const;
|
||||
|
||||
std::string DeclareAddress(const std::string& var_name,
|
||||
const std::string& address) const;
|
||||
|
||||
|
@ -180,6 +180,23 @@ cc_library(
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "prelu_tester",
|
||||
testonly = 1,
|
||||
srcs = ["prelu_tester.cc"],
|
||||
hdrs = ["prelu_tester.h"],
|
||||
deps = [
|
||||
"//tensorflow/lite:framework",
|
||||
"//tensorflow/lite:schema_fbs_version",
|
||||
"//tensorflow/lite/c:common",
|
||||
"//tensorflow/lite/kernels:builtin_ops",
|
||||
"//tensorflow/lite/schema:schema_fbs",
|
||||
"@FP16",
|
||||
"@com_google_googletest//:gtest",
|
||||
"@flatbuffers",
|
||||
],
|
||||
)
|
||||
|
||||
cc_library(
|
||||
name = "reduce_tester",
|
||||
testonly = 1,
|
||||
@ -527,6 +544,21 @@ cc_test(
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "prelu_test",
|
||||
srcs = ["prelu_test.cc"],
|
||||
linkopts = select({
|
||||
"//tensorflow:emscripten": EMSCRIPTEN_LINKOPTS,
|
||||
"//conditions:default": [],
|
||||
}),
|
||||
deps = [
|
||||
":prelu_tester",
|
||||
":test_main",
|
||||
":xnnpack_delegate_test_mode",
|
||||
"@com_google_googletest//:gtest",
|
||||
],
|
||||
)
|
||||
|
||||
cc_test(
|
||||
name = "relu_test",
|
||||
srcs = ["relu_test.cc"],
|
||||
|
@ -21,7 +21,6 @@ limitations under the License.
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
583
tensorflow/lite/delegates/xnnpack/prelu_test.cc
Normal file
583
tensorflow/lite/delegates/xnnpack/prelu_test.cc
Normal file
@ -0,0 +1,583 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <memory>
|
||||
#include <random>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/delegates/xnnpack/prelu_tester.h"
|
||||
#include "tensorflow/lite/delegates/xnnpack/xnnpack_delegate.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy4D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({batch, height, width, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, 4DBy4DBroadcastChannels) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({1, 1, 1, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy4DBroadcastWidth) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({1, 1, width, 1})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy4DBroadcastHeight) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({1, height, 1, 1})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy4DBroadcastBatch) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({batch, 1, 1, 1})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy4DBroadcastHeightWidthChannels) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({1, height, width, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy3D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({height, width, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy2D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({width, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, 4DBy1D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_4DBy0D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_3DBy3D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({batch, width, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, 3DBy3DBroadcastChannels) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({1, 1, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_3DBy3DBroadcastWidth) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({1, width, 1})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_3DBy3DBroadcastBatch) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({batch, 1, 1})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_3DBy3DBroadcastWidthChannels) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({1, width, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_3DBy2D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({width, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, 3DBy1D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_3DBy0D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, width, channels})
|
||||
.SlopeShape({})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_2DBy2D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, channels})
|
||||
.SlopeShape({batch, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, 2DBy2DBroadcastChannels) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, channels})
|
||||
.SlopeShape({1, channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_2DBy2DBroadcastBatch) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, channels})
|
||||
.SlopeShape({batch, 1})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, 2DBy1D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, channels})
|
||||
.SlopeShape({channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_2DBy0D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, channels})
|
||||
.SlopeShape({})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, 1DBy1D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
|
||||
PreluTester().InputShape({batch}).SlopeShape({batch}).Test(
|
||||
xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
// TODO(b/159727692)
|
||||
TEST(Prelu, DISABLED_1DBy0D) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
|
||||
PreluTester().InputShape({batch}).SlopeShape({}).Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, FP16Weights) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({channels})
|
||||
.FP16Weights()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, SparseWeights) {
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(nullptr),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({channels})
|
||||
.SparseWeights()
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
TEST(Prelu, MultiThreading) {
|
||||
TfLiteXNNPackDelegateOptions delegate_options =
|
||||
TfLiteXNNPackDelegateOptionsDefault();
|
||||
delegate_options.num_threads = 2;
|
||||
std::unique_ptr<TfLiteDelegate, decltype(&TfLiteXNNPackDelegateDelete)>
|
||||
xnnpack_delegate(TfLiteXNNPackDelegateCreate(&delegate_options),
|
||||
TfLiteXNNPackDelegateDelete);
|
||||
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto shape_rng =
|
||||
std::bind(std::uniform_int_distribution<int32_t>(2, 5), std::ref(rng));
|
||||
const auto batch = shape_rng();
|
||||
const auto height = shape_rng();
|
||||
const auto width = shape_rng();
|
||||
const auto channels = shape_rng();
|
||||
|
||||
PreluTester()
|
||||
.InputShape({batch, height, width, channels})
|
||||
.SlopeShape({channels})
|
||||
.Test(xnnpack_delegate.get());
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
237
tensorflow/lite/delegates/xnnpack/prelu_tester.cc
Normal file
237
tensorflow/lite/delegates/xnnpack/prelu_tester.cc
Normal file
@ -0,0 +1,237 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/lite/delegates/xnnpack/prelu_tester.h"
|
||||
|
||||
#include <array>
|
||||
#include <cstdint>
|
||||
#include <functional>
|
||||
#include <numeric>
|
||||
#include <random>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include <fp16.h>
|
||||
#include "flatbuffers/flatbuffers.h" // from @flatbuffers
|
||||
#include "tensorflow/lite/interpreter.h"
|
||||
#include "tensorflow/lite/kernels/register.h"
|
||||
#include "tensorflow/lite/model.h"
|
||||
#include "tensorflow/lite/schema/schema_generated.h"
|
||||
#include "tensorflow/lite/version.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
void PreluTester::Test(TfLiteDelegate* delegate) const {
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto input_rng = std::bind(std::uniform_real_distribution<float>(-1.0f, 1.0f),
|
||||
std::ref(rng));
|
||||
|
||||
std::vector<char> buffer = CreateTfLiteModel();
|
||||
const Model* model = GetModel(buffer.data());
|
||||
|
||||
std::unique_ptr<Interpreter> delegate_interpreter;
|
||||
ASSERT_EQ(
|
||||
InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
|
||||
&delegate_interpreter),
|
||||
kTfLiteOk);
|
||||
std::unique_ptr<Interpreter> default_interpreter;
|
||||
ASSERT_EQ(
|
||||
InterpreterBuilder(model, ::tflite::ops::builtin::BuiltinOpResolver())(
|
||||
&default_interpreter),
|
||||
kTfLiteOk);
|
||||
|
||||
ASSERT_TRUE(delegate_interpreter);
|
||||
ASSERT_TRUE(default_interpreter);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->inputs().size(), 1);
|
||||
ASSERT_EQ(default_interpreter->inputs().size(), 1);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->outputs().size(), 1);
|
||||
ASSERT_EQ(default_interpreter->outputs().size(), 1);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
ASSERT_EQ(default_interpreter->AllocateTensors(), kTfLiteOk);
|
||||
|
||||
ASSERT_EQ(delegate_interpreter->ModifyGraphWithDelegate(delegate), kTfLiteOk);
|
||||
|
||||
float* default_input_data = default_interpreter->typed_tensor<float>(
|
||||
default_interpreter->inputs()[0]);
|
||||
std::generate(default_input_data,
|
||||
default_input_data + ComputeSize(InputShape()),
|
||||
std::ref(input_rng));
|
||||
|
||||
float* xnnpack_input_data = delegate_interpreter->typed_tensor<float>(
|
||||
delegate_interpreter->inputs()[0]);
|
||||
std::copy(default_input_data, default_input_data + ComputeSize(InputShape()),
|
||||
xnnpack_input_data);
|
||||
|
||||
ASSERT_EQ(default_interpreter->Invoke(), kTfLiteOk);
|
||||
ASSERT_EQ(delegate_interpreter->Invoke(), kTfLiteOk);
|
||||
|
||||
float* default_output_data = default_interpreter->typed_tensor<float>(
|
||||
default_interpreter->outputs()[0]);
|
||||
float* xnnpack_output_data = delegate_interpreter->typed_tensor<float>(
|
||||
delegate_interpreter->outputs()[0]);
|
||||
|
||||
for (size_t i = 0; i < ComputeSize(OutputShape()); i++) {
|
||||
ASSERT_EQ(default_output_data[i], xnnpack_output_data[i]);
|
||||
}
|
||||
}
|
||||
|
||||
std::vector<char> PreluTester::CreateTfLiteModel() const {
|
||||
std::random_device random_device;
|
||||
auto rng = std::mt19937(random_device());
|
||||
auto slope_rng = std::bind(std::uniform_real_distribution<float>(0.25f, 0.5f),
|
||||
std::ref(rng));
|
||||
|
||||
flatbuffers::FlatBufferBuilder builder;
|
||||
std::vector<flatbuffers::Offset<OperatorCode>> operator_codes{
|
||||
{CreateOperatorCode(builder, BuiltinOperator_PRELU)}};
|
||||
if (FP16Weights()) {
|
||||
operator_codes.emplace_back(
|
||||
CreateOperatorCode(builder, BuiltinOperator_DEQUANTIZE));
|
||||
} else if (SparseWeights()) {
|
||||
operator_codes.emplace_back(
|
||||
CreateOperatorCode(builder, BuiltinOperator_DENSIFY));
|
||||
}
|
||||
|
||||
std::vector<flatbuffers::Offset<Buffer>> buffers{{
|
||||
CreateBuffer(builder, builder.CreateVector({})),
|
||||
}};
|
||||
|
||||
if (FP16Weights()) {
|
||||
std::vector<uint16_t> slope_data(ComputeSize(SlopeShape()));
|
||||
std::generate(slope_data.begin(), slope_data.end(),
|
||||
std::bind(fp16_ieee_from_fp32_value, slope_rng));
|
||||
|
||||
buffers.push_back(CreateBuffer(
|
||||
builder, builder.CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(slope_data.data()),
|
||||
sizeof(uint16_t) * slope_data.size())));
|
||||
} else {
|
||||
std::vector<float> slope_data(ComputeSize(SlopeShape()));
|
||||
std::generate(slope_data.begin(), slope_data.end(), slope_rng);
|
||||
|
||||
buffers.push_back(CreateBuffer(
|
||||
builder, builder.CreateVector(
|
||||
reinterpret_cast<const uint8_t*>(slope_data.data()),
|
||||
sizeof(float) * slope_data.size())));
|
||||
}
|
||||
|
||||
std::vector<flatbuffers::Offset<Tensor>> tensors;
|
||||
std::vector<flatbuffers::Offset<Operator>> operators;
|
||||
if (FP16Weights()) {
|
||||
tensors.emplace_back(CreateTensor(
|
||||
builder,
|
||||
builder.CreateVector<int32_t>(SlopeShape().data(), SlopeShape().size()),
|
||||
TensorType_FLOAT16, /*buffer=*/1));
|
||||
} else if (SparseWeights()) {
|
||||
const int dims_count = SlopeShape().size();
|
||||
std::vector<flatbuffers::Offset<DimensionMetadata>> dim_metadata(
|
||||
dims_count);
|
||||
std::vector<int> traversal_order(dims_count);
|
||||
for (int i = 0; i < dims_count; i++) {
|
||||
traversal_order[i] = i;
|
||||
dim_metadata[i] = CreateDimensionMetadata(builder, DimensionType_DENSE,
|
||||
SlopeShape()[i]);
|
||||
}
|
||||
const flatbuffers::Offset<SparsityParameters> sparsity_param =
|
||||
CreateSparsityParameters(builder, builder.CreateVector(traversal_order),
|
||||
0, builder.CreateVector(dim_metadata));
|
||||
tensors.emplace_back(CreateTensor(
|
||||
builder,
|
||||
builder.CreateVector<int32_t>(SlopeShape().data(), SlopeShape().size()),
|
||||
TensorType_FLOAT32, /*buffer=*/1, /*name=*/0, /*quantization=*/0,
|
||||
/*is_variable=*/false, /*sparsity=*/sparsity_param));
|
||||
}
|
||||
if (FP16Weights()) {
|
||||
const std::array<int32_t, 1> dequantize_inputs{{0}};
|
||||
const std::array<int32_t, 1> dequantize_outputs{{2}};
|
||||
operators.emplace_back(CreateOperator(
|
||||
builder, /*opcode_index=*/1,
|
||||
builder.CreateVector<int32_t>(dequantize_inputs.data(),
|
||||
dequantize_inputs.size()),
|
||||
builder.CreateVector<int32_t>(dequantize_outputs.data(),
|
||||
dequantize_outputs.size())));
|
||||
} else if (SparseWeights()) {
|
||||
const std::array<int32_t, 1> densify_inputs{{0}};
|
||||
const std::array<int32_t, 1> densify_outputs{{2}};
|
||||
operators.emplace_back(
|
||||
CreateOperator(builder, /*opcode_index=*/1,
|
||||
builder.CreateVector<int32_t>(densify_inputs.data(),
|
||||
densify_inputs.size()),
|
||||
builder.CreateVector<int32_t>(densify_outputs.data(),
|
||||
densify_outputs.size())));
|
||||
}
|
||||
tensors.emplace_back(CreateTensor(
|
||||
builder,
|
||||
builder.CreateVector<int32_t>(InputShape().data(), InputShape().size()),
|
||||
TensorType_FLOAT32));
|
||||
tensors.emplace_back(CreateTensor(
|
||||
builder,
|
||||
builder.CreateVector<int32_t>(SlopeShape().data(), SlopeShape().size()),
|
||||
TensorType_FLOAT32,
|
||||
/*buffer=*/(FP16Weights() || SparseWeights()) ? 0 : 1));
|
||||
tensors.emplace_back(CreateTensor(
|
||||
builder,
|
||||
builder.CreateVector<int32_t>(OutputShape().data(), OutputShape().size()),
|
||||
TensorType_FLOAT32));
|
||||
|
||||
const std::array<int32_t, 2> op_inputs{
|
||||
{static_cast<int>(tensors.size()) - 3,
|
||||
static_cast<int>(tensors.size()) - 2}};
|
||||
const std::array<int32_t, 1> op_outputs{
|
||||
{static_cast<int>(tensors.size()) - 1}};
|
||||
operators.emplace_back(CreateOperator(
|
||||
builder, /*opcode_index=*/0,
|
||||
builder.CreateVector<int32_t>(op_inputs.data(), op_inputs.size()),
|
||||
builder.CreateVector<int32_t>(op_outputs.data(), op_outputs.size())));
|
||||
|
||||
const std::array<int32_t, 1> subgraph_inputs{
|
||||
{static_cast<int32_t>(tensors.size() - 3)}};
|
||||
const std::array<int32_t, 1> subgraph_outputs{
|
||||
{static_cast<int32_t>(tensors.size()) - 1}};
|
||||
flatbuffers::Offset<SubGraph> subgraph = CreateSubGraph(
|
||||
builder, builder.CreateVector(tensors.data(), tensors.size()),
|
||||
builder.CreateVector<int32_t>(subgraph_inputs.data(),
|
||||
subgraph_inputs.size()),
|
||||
builder.CreateVector<int32_t>(subgraph_outputs.data(),
|
||||
subgraph_outputs.size()),
|
||||
builder.CreateVector(operators.data(), operators.size()));
|
||||
|
||||
flatbuffers::Offset<flatbuffers::String> description =
|
||||
builder.CreateString("PReLU model");
|
||||
|
||||
flatbuffers::Offset<Model> model_buffer = CreateModel(
|
||||
builder, TFLITE_SCHEMA_VERSION,
|
||||
builder.CreateVector(operator_codes.data(), operator_codes.size()),
|
||||
builder.CreateVector(&subgraph, 1), description,
|
||||
builder.CreateVector(buffers.data(), buffers.size()));
|
||||
|
||||
builder.Finish(model_buffer);
|
||||
|
||||
return std::vector<char>(builder.GetBufferPointer(),
|
||||
builder.GetBufferPointer() + builder.GetSize());
|
||||
}
|
||||
|
||||
int32_t PreluTester::ComputeSize(const std::vector<int32_t>& shape) {
|
||||
return std::accumulate(shape.cbegin(), shape.cend(), 1,
|
||||
std::multiplies<int32_t>());
|
||||
}
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
88
tensorflow/lite/delegates/xnnpack/prelu_tester.h
Normal file
88
tensorflow/lite/delegates/xnnpack/prelu_tester.h
Normal file
@ -0,0 +1,88 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#ifndef TENSORFLOW_LITE_DELEGATES_XNNPACK_PRELU_TESTER_H_
|
||||
#define TENSORFLOW_LITE_DELEGATES_XNNPACK_PRELU_TESTER_H_
|
||||
|
||||
#include <cstdint>
|
||||
#include <vector>
|
||||
|
||||
#include <gtest/gtest.h>
|
||||
#include "tensorflow/lite/c/common.h"
|
||||
|
||||
namespace tflite {
|
||||
namespace xnnpack {
|
||||
|
||||
class PreluTester {
|
||||
public:
|
||||
PreluTester() = default;
|
||||
PreluTester(const PreluTester&) = delete;
|
||||
PreluTester& operator=(const PreluTester&) = delete;
|
||||
|
||||
inline PreluTester& InputShape(std::initializer_list<int32_t> shape) {
|
||||
for (auto it = shape.begin(); it != shape.end(); ++it) {
|
||||
EXPECT_GT(*it, 0);
|
||||
}
|
||||
input_shape_ = std::vector<int32_t>(shape.begin(), shape.end());
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline const std::vector<int32_t>& InputShape() const { return input_shape_; }
|
||||
|
||||
inline PreluTester& SlopeShape(std::initializer_list<int32_t> shape) {
|
||||
for (auto it = shape.begin(); it != shape.end(); ++it) {
|
||||
EXPECT_GT(*it, 0);
|
||||
}
|
||||
slope_shape_ = std::vector<int32_t>(shape.begin(), shape.end());
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline const std::vector<int32_t>& SlopeShape() const { return slope_shape_; }
|
||||
|
||||
inline const std::vector<int32_t>& OutputShape() const {
|
||||
return InputShape();
|
||||
}
|
||||
|
||||
inline PreluTester& FP16Weights() {
|
||||
fp16_weights_ = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline bool FP16Weights() const { return fp16_weights_; }
|
||||
|
||||
inline PreluTester& SparseWeights() {
|
||||
sparse_weights_ = true;
|
||||
return *this;
|
||||
}
|
||||
|
||||
inline bool SparseWeights() const { return sparse_weights_; }
|
||||
|
||||
void Test(TfLiteDelegate* delegate) const;
|
||||
|
||||
private:
|
||||
std::vector<char> CreateTfLiteModel() const;
|
||||
|
||||
static int32_t ComputeSize(const std::vector<int32_t>& shape);
|
||||
|
||||
std::vector<int32_t> input_shape_;
|
||||
std::vector<int32_t> slope_shape_;
|
||||
bool fp16_weights_ = false;
|
||||
bool sparse_weights_ = false;
|
||||
};
|
||||
|
||||
} // namespace xnnpack
|
||||
} // namespace tflite
|
||||
|
||||
#endif // TENSORFLOW_LITE_DELEGATES_XNNPACK_PRELU_TESTER_H_
|
@ -2266,7 +2266,8 @@ class Subgraph {
|
||||
const TfLiteTensor& input_tensor = tensors[node->inputs->data[0]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, input_tensor, node->inputs->data[0], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 4,
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, input_tensor, 1,
|
||||
XNN_MAX_TENSOR_DIMS,
|
||||
node->inputs->data[0]));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
|
||||
logging_context, input_tensor, node->inputs->data[0], node_index));
|
||||
@ -2284,7 +2285,8 @@ class Subgraph {
|
||||
const TfLiteTensor& output_tensor = tensors[node->outputs->data[0]];
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorFloatType(
|
||||
logging_context, output_tensor, node->outputs->data[0], node_index));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 4,
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorShape(logging_context, output_tensor, 1,
|
||||
XNN_MAX_TENSOR_DIMS,
|
||||
node->outputs->data[0]));
|
||||
TF_LITE_ENSURE_STATUS(CheckTensorNonDynamicAllocation(
|
||||
logging_context, output_tensor, node->outputs->data[0], node_index));
|
||||
|
@ -86,6 +86,8 @@ upper_tabs:
|
||||
path: /lite/convert/rnn
|
||||
- title: "Add metadata"
|
||||
path: /lite/convert/metadata
|
||||
- title: "Composite operation fusion"
|
||||
path: /lite/convert/operation_fusion
|
||||
- title: "1.x compatibility"
|
||||
path: /lite/convert/1x_compatibility
|
||||
|
||||
|
270
tensorflow/lite/g3doc/convert/operation_fusion.md
Normal file
270
tensorflow/lite/g3doc/convert/operation_fusion.md
Normal file
@ -0,0 +1,270 @@
|
||||
# TensorFlow operation fusion
|
||||
|
||||
## Overview
|
||||
|
||||
This page describes the design and steps needed to convert composite operations
|
||||
in TensorFlow to fused operations in TensorFlow Lite. This infrastructure is
|
||||
general purpose and supports conversion of any composite operation in TensorFlow
|
||||
to a corresponding fused operation in TensorFlow Lite.
|
||||
|
||||
An example use of this infrastructure is TensorFlow RNN operation fusion to
|
||||
TensorFlow Lite, as detailed
|
||||
[here](https://www.tensorflow.org/lite/convert/rnn).
|
||||
|
||||
### What are fused operations
|
||||
|
||||

|
||||
|
||||
TensorFlow operations can either be primitive ops e.g.
|
||||
[tf.add](https://www.tensorflow.org/api_docs/python/tf/math/add) or they can be
|
||||
composed from other primitive operations e.g.
|
||||
[tf.einsum](https://www.tensorflow.org/api_docs/python/tf/einsum). A primitive
|
||||
operation shows up as a single node in the TensorFlow graph while.a composite
|
||||
operation is a collection of nodes in the TensorFlow graph. Executing a
|
||||
composite operation is equivalent to executing each of its constituent primitive
|
||||
operations.
|
||||
|
||||
A fused operation corresponds to a single operation that subsumes all the
|
||||
computation performed by each primitive operation within the corresponding
|
||||
composite operation.
|
||||
|
||||
### Benefits of fused operations
|
||||
|
||||
Fused operations exist to maximize the performance of their underlying kernel
|
||||
implementations, by optimizing the overall computation and reducing memory
|
||||
footprint. This is very valuable, especially for low-latency inference workloads
|
||||
and resource constrained mobile platforms.
|
||||
|
||||
Fused operations also provide a higher level interface to define complex
|
||||
transformations like quantization, which would otherwise be infeasible or very
|
||||
hard to do at a more granular level.
|
||||
|
||||
TensorFlow Lite has many instances of fused operations for the reasons
|
||||
articulated above. These fused operations typically correspond to composite
|
||||
operations in the source TensorFlow program. Examples of composite operations in
|
||||
TensorFlow that are implemented as a single fused operation in TensorFlow Lite
|
||||
include various RNN operations like Unidirectional and Bidirectional sequence
|
||||
LSTM, convolution (conv2d, bias add, relu), fully connected (matmul, bias add,
|
||||
relu) and more. In TensorFlow Lite, LSTM quantization is currently only
|
||||
implemented in the fused LSTM operations.
|
||||
|
||||
### Challenges with fused operations
|
||||
|
||||
Converting composite operations from TensorFlow to fused operations in
|
||||
TensorFlow Lite is a hard problem. This is because:
|
||||
|
||||
1. Composite operations are represented in the TensorFlow graph as a set of
|
||||
primitive operations without a well defined boundary. It can be very
|
||||
challenging to identify (e.g. via pattern matching) the sub-graph
|
||||
corresponding to such a composite operation.
|
||||
|
||||
1. There may be more than one TensorFlow implementation targeting a fused
|
||||
TensorFlow Lite operation. For example, there are many LSTM implementations
|
||||
in TensorFlow (Keras, Babelfish/lingvo etc) and each of these is composed of
|
||||
different primitive operations but they all could still be converted to the
|
||||
same fused LSTM operation in TensorFlow Lite.
|
||||
|
||||
As such, conversion of fused operations has proven quite challenging.
|
||||
|
||||
## Converting from composite to fused operation
|
||||
|
||||
The overall architecture for converting TensorFlow composite operations to
|
||||
TensorFlow Lite fused operations is below:
|
||||
|
||||

|
||||
|
||||
### Wrap the composite operation in a `tf.function`
|
||||
|
||||
In the TensorFlow model source code, identify and abstract out the composite
|
||||
operation into a `tf.function` with the
|
||||
[experimental\_implements](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/function.py#L88)
|
||||
function annotation. See an example of [embedding lookup](#composing_ops). The
|
||||
function defines the interface and its arguments should be used to implement the
|
||||
conversion logic.
|
||||
|
||||
### Write conversion code
|
||||
|
||||
The conversion code is written per the interface of the function with the
|
||||
`implements` annotation. See an example fusion for
|
||||
[embedding lookup](#fusion_code). Conceptually, the conversion code replaces the
|
||||
composite implementation of this interface with the fused one.
|
||||
|
||||
In the prepare-composite-functions pass, plugin in your
|
||||
[conversion code](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L108).
|
||||
|
||||
In more advanced usages, it is possible to implement complex transformations of
|
||||
the composite operation's operands in order to derive the operands of the fused
|
||||
operation. See
|
||||
[Keras LSTM](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L627).
|
||||
conversion code as an example.
|
||||
|
||||
### Convert to TensorFlow Lite
|
||||
|
||||
Use the
|
||||
[TFLiteConverter.from_saved_model](https://www.tensorflow.org/api_docs/python/tf/lite/TFLiteConverter#from_saved_model)
|
||||
API to convert to TensorFlow Lite.
|
||||
|
||||
## Under the hood
|
||||
|
||||
<a id="under_the_hood"></a>
|
||||
|
||||
We now describe high level details of the overall design in converting to fused
|
||||
operations in TensorFlow Lite.
|
||||
|
||||
### Composing operations in TensorFlow
|
||||
|
||||
<a id="composing_ops"></a>
|
||||
|
||||
The use of `tf.function` with the
|
||||
[experimental\_implements](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/eager/function.py#L88)
|
||||
function attribute allows users to explicitly compose new operations using
|
||||
TensorFlow primitive operations and specify the interface that the resultant
|
||||
composite operation implements. This is very useful as it provides:
|
||||
|
||||
1. A well-defined boundary for the composite operation in the underlying
|
||||
TensorFlow graph.
|
||||
1. Explicitly specify the interface that this operation implements. The
|
||||
arguments of the `tf.function` correspond to the arguments of this
|
||||
interface.
|
||||
|
||||
As an example, let’s consider a composite operation defined in
|
||||
[Lingvo/TensorFlow](https://github.com/tensorflow/lingvo) to implement embedding
|
||||
lookup. This maps to a fused operation in TensorFlow Lite.
|
||||
|
||||
```python
|
||||
@tf.function(
|
||||
experimental_implements="lingvo.embedding_lookup")
|
||||
def EmbFprop(embs, ids_vec):
|
||||
"""Embedding forward prop.
|
||||
|
||||
Effectively, it computes:
|
||||
num = size of ids_vec
|
||||
rets = zeros([num, embedding dim])
|
||||
for i in range(num):
|
||||
rets[i, :] = embs[ids_vec[i], :]
|
||||
return rets
|
||||
|
||||
Args:
|
||||
embs: The embedding matrix.
|
||||
ids_vec: A vector of int32 embedding ids.
|
||||
|
||||
Returns:
|
||||
The result of embedding lookups. A matrix of shape
|
||||
[num ids in ids_vec, embedding dims].
|
||||
"""
|
||||
num = tf.shape(ids_vec)[0]
|
||||
rets = inplace_ops.empty([num] + emb_shape_suf, py_utils.FPropDtype(p))
|
||||
|
||||
def EmbFpropLoop(i, embs, ids_vec, rets):
|
||||
# row_id = ids_vec[i]
|
||||
row_id = tf.gather(ids_vec, i)
|
||||
# row = embs[row_id]
|
||||
row = tf.reshape(tf.gather(embs, row_id), [1] + emb_shape_suf)
|
||||
# rets[i] = row
|
||||
rets = inplace_ops.alias_inplace_update(rets, [i], row)
|
||||
return embs, ids_vec, rets
|
||||
|
||||
_, _, rets = functional_ops.For(
|
||||
start=0,
|
||||
limit=num,
|
||||
delta=1,
|
||||
inputs=[embs, ids_vec, rets],
|
||||
body=EmbFpropLoop,
|
||||
rewrite_with_while=compiled)
|
||||
if len(weight_shape) > 2:
|
||||
rets = tf.reshape(rets, [num, symbolic.ToStatic(p.embedding_dim)])
|
||||
return rets
|
||||
```
|
||||
|
||||
By making models use composite operations via `tf.function` as illustrated
|
||||
above, it becomes possible to build a general infrastructure to **identify and
|
||||
convert** such operations to fused TensorFlow Lite operations.
|
||||
|
||||
### Extending the TensorFlow Lite converter
|
||||
|
||||
The TensorFlow Lite converter that was released earlier this year only supported
|
||||
importing TensorFlow models as a graph with all variables replaced with their
|
||||
corresponding constant values. This does not work for operation fusion since
|
||||
such graphs have all functions inlined so that the variables can be turned into
|
||||
constants.
|
||||
|
||||
In order to leverage the `tf.function` with the `experimental_implements`
|
||||
feature during the conversion process, the functions need to be preserved until
|
||||
later in the conversion process.
|
||||
|
||||
As such, we implemented a new workflow of importing and converting TensorFlow
|
||||
models in the converter to support the composite operation fusion use case.
|
||||
Specifically, the new features added are:
|
||||
|
||||
1. Importing TensorFlow
|
||||
[saved models into MLIR](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/translate/import_model.cc#L3593)
|
||||
1. [fuse composite operations](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L103)
|
||||
1. [variable mutability analysis](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/transforms/optimize_global_tensors.cc#L43)
|
||||
1. [freeze all read-only variables](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/tensorflow/transforms/freeze_global_tensors.cc#L44)
|
||||
|
||||
This allows us to perform operation fusion using the functions representing the
|
||||
composite operations prior to function inlining and variable freezing.
|
||||
|
||||
### Implementing operation fusion
|
||||
|
||||
Let’s look at the operation fusion pass in more detail. This pass does the
|
||||
following:
|
||||
|
||||
1. Loop through all functions in the MLIR module.
|
||||
1. If a function has the tf.\_implements attribute, based on the attribute
|
||||
value, calls the appropriate operation fusion utility.
|
||||
1. The operation fusion utility operates on the function’s operands and
|
||||
attributes (which serve as the interface for the conversion) and replaces
|
||||
the body of the function with an equivalent function body containing the
|
||||
fused operation.
|
||||
1. In many cases, the replaced body will contain operations other than the
|
||||
fused operation. These correspond to some static transforms on the
|
||||
function’s operands in order to obtain the operands of the fused operation.
|
||||
Since these computations can all be constant folded away, they would not be
|
||||
present in the exported flatbuffer where only the fused operation would
|
||||
exist.
|
||||
|
||||
Here is code snippet from the pass showing the main workflow:
|
||||
|
||||
```
|
||||
void PrepareCompositeFunctionsPass::ConvertTFImplements(FuncOp func,
|
||||
StringAttr attr) {
|
||||
if (attr.getValue() == "lingvo.embedding_lookup") {
|
||||
func.eraseBody();
|
||||
func.addEntryBlock();
|
||||
// Convert the composite embedding_lookup function body to a
|
||||
// TFLite fused embedding_lookup op.
|
||||
ConvertEmbeddedLookupFunc convert_embedded_lookup(func);
|
||||
if (failed(convert_embedded_lookup.VerifySignature())) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
convert_embedded_lookup.RewriteFunc();
|
||||
} else if (attr.getValue() == mlir::TFL::kKerasLstm) {
|
||||
func.eraseBody();
|
||||
func.addEntryBlock();
|
||||
OpBuilder builder(func.getBody());
|
||||
if (failed(ConvertKerasLSTMLayer(func, &builder))) {
|
||||
return signalPassFailure();
|
||||
}
|
||||
} else if (.....) /* Other fusions can plug in here */
|
||||
}
|
||||
```
|
||||
|
||||
Here is code snippet showing mapping this composite operation to a fused
|
||||
operation in TensorFlow Lite leveraging the function as a conversion interface.
|
||||
|
||||
<a id="fusion_code"></a>
|
||||
|
||||
```C++
|
||||
void RewriteFunc() {
|
||||
Value lookup = func_.getArgument(1);
|
||||
Value value = func_.getArgument(0);
|
||||
auto output_type = func_.getType().getResult(0);
|
||||
|
||||
OpBuilder builder(func_.getBody());
|
||||
auto op = builder.create<mlir::TFL::EmbeddingLookupOp>(
|
||||
func_.getLoc(), output_type, lookup, value);
|
||||
|
||||
builder.create<mlir::ReturnOp>(func_.getLoc(), op.getResult());
|
||||
}
|
||||
```
|
@ -30,8 +30,8 @@ This document contains [example usages](#examples) of the API and
|
||||
### Converting a SavedModel <a name="saved_model"></a>
|
||||
|
||||
The following example shows how to convert a
|
||||
[SavedModel](https://www.tensorflow.org/guide/saved_model) into a
|
||||
TensorFlow Lite [`FlatBuffer`](https://google.github.io/flatbuffers/).
|
||||
[SavedModel](https://www.tensorflow.org/guide/saved_model) into a TensorFlow
|
||||
Lite [`FlatBuffer`](https://google.github.io/flatbuffers/).
|
||||
|
||||
```python
|
||||
import tensorflow as tf
|
||||
@ -97,6 +97,24 @@ with tf.io.gfile.GFile('model.tflite', 'wb') as f:
|
||||
f.write(tflite_model)
|
||||
```
|
||||
|
||||
If your model requires specifying the input shape, use `tf.keras.layers.Input`
|
||||
or `tf.keras.layers.InputLayer` to create a Keras model with a fixed input shape
|
||||
as seen below or use the [`from_concrete_functions`](#concrete_function)
|
||||
classmethod as shown in the prior section to set the shape of the input arrays
|
||||
prior to conversion.
|
||||
|
||||
```python
|
||||
input = tf.keras.layers.Input(shape=(1), batch_size=1)
|
||||
dense_layer = tf.keras.layers.Dense(units=1, input_shape=[1])
|
||||
model = tf.keras.Model(input, dense_layer(input))
|
||||
```
|
||||
|
||||
```python
|
||||
model = tf.keras.models.Sequential(
|
||||
[tf.keras.layers.InputLayer(input_shape=(1), batch_size=1),
|
||||
tf.keras.layers.Dense(units=1, input_shape=[1])])
|
||||
```
|
||||
|
||||
### Converting a concrete function <a name="concrete_function"></a>
|
||||
|
||||
The following example shows how to convert a TensorFlow
|
||||
|
@ -3,9 +3,9 @@
|
||||
## Overview
|
||||
|
||||
TensorFlow Lite supports converting TensorFlow RNN models to TensorFlow Lite’s
|
||||
fused LSTM operators. Fused operators exist to maximize the performance of their
|
||||
underlying kernel implementations, as well as provide a higher level interface
|
||||
to define complex transformations like quantizatization.
|
||||
fused LSTM operations. Fused operations exist to maximize the performance of
|
||||
their underlying kernel implementations, as well as provide a higher level
|
||||
interface to define complex transformations like quantizatization.
|
||||
|
||||
Since there are many variants of RNN APIs in TensorFlow, our approach has been
|
||||
two fold:
|
||||
@ -23,15 +23,16 @@ two fold:
|
||||
|
||||
## Converter API
|
||||
|
||||
Currently this feature is available through the
|
||||
[tf-nightly](https://pypi.org/project/tf-nightly/) pip or from head. This will
|
||||
be available in the TensorFlow 2.3 release.
|
||||
The feature is part of TensorFlow 2.3 release. It is also available through the
|
||||
[tf-nightly](https://pypi.org/project/tf-nightly/) pip or from head.
|
||||
|
||||
This conversion functionality is available when converting to TensorFlow Lite
|
||||
via a SavedModel or from the Keras model directly. See example usages.
|
||||
|
||||
### From saved model
|
||||
|
||||
<a id="from_saved_model"></a>
|
||||
|
||||
```
|
||||
# build a saved model. Here concrete_function is the exported function
|
||||
# corresponding to the TensorFlow model containing one or more
|
||||
@ -64,6 +65,8 @@ illustrates the end to end usage with the TensorFlow Lite interpreter.
|
||||
|
||||
## TensorFlow RNNs APIs supported
|
||||
|
||||
<a id="rnn_apis"></a>
|
||||
|
||||
### Keras LSTM conversion (recommended)
|
||||
|
||||
We support out-of-the-box conversion of Keras LSTM to TensorFlow Lite. For
|
||||
@ -75,13 +78,17 @@ details on how this works please refer to the
|
||||
Also important is to highlight the TensorFlow Lite’s LSTM contract with respect
|
||||
to the Keras operation definition:
|
||||
|
||||
1. The dimension 0 of the input tensor is the batch size.
|
||||
1. The dimension 0 of the recurrent\_weight tensor is the number of outputs.
|
||||
1. The dimension 0 of the **input** tensor is the batch size.
|
||||
1. The dimension 0 of the **recurrent\_weight** tensor is the number of
|
||||
outputs.
|
||||
1. The **weight** and **recurrent\_kernel** tensors are transposed.
|
||||
1. The transposed weight, transposed recurrent\_kernel and bias tensors are
|
||||
1. The transposed weight, transposed recurrent\_kernel and **bias** tensors are
|
||||
split into 4 equal sized tensors along the dimension 0. These correspond to
|
||||
**input gate, forget gate, cell, and output gate**.
|
||||
|
||||
See the detailed conversion code from Keras LSTM to TensorFlow Lite
|
||||
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/utils/lstm_utils.cc#L627).
|
||||
|
||||
#### Keras LSTM Variants
|
||||
|
||||
##### Time major
|
||||
@ -98,7 +105,7 @@ forward and one for backward, see examples
|
||||
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/keras/layers/wrappers.py#L381).
|
||||
Once we see the go\_backward attribute, we recognize it as backward LSTM, then
|
||||
we group forward & backward LSTM together. **This is future work.** Currently,
|
||||
this creates two UnidirectionalSequenceLSTM operators in the TensorFlow Lite
|
||||
this creates two UnidirectionalSequenceLSTM operations in the TensorFlow Lite
|
||||
model.
|
||||
|
||||
### User-defined LSTM conversion examples
|
||||
@ -134,7 +141,7 @@ MLIR-pass
|
||||
[here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/compiler/mlir/lite/transforms/prepare_composite_functions_tf.cc#L108).
|
||||
The function’s interface should be treated like an API contract and should
|
||||
contain the arguments needed to convert to fused TensorFlow Lite LSTM
|
||||
operators - i.e. input, bias, weights, projection, layer normalization, etc. It
|
||||
operations - i.e. input, bias, weights, projection, layer normalization, etc. It
|
||||
is preferable for the tensors passed as arguments to this function to have known
|
||||
rank (i.e. RankedTensorType in MLIR). This makes it much easier to write
|
||||
conversion code that can assume these tensors as RankedTensorType and helps
|
||||
@ -189,5 +196,5 @@ follows:
|
||||
the user program. Such a TensorFlow program can still be converted to
|
||||
TensorFlow Lite using the feature being described here.
|
||||
1. Bidirectional LSTM is currently modelled as two UnidirectionalSequenceLSTM
|
||||
operators in TensorFlow Lite. This will be replaced with a single
|
||||
operations in TensorFlow Lite. This will be replaced with a single
|
||||
BidirectionalSequenceLSTM op.
|
||||
|
@ -1147,11 +1147,8 @@ models:
|
||||
* `CALL`
|
||||
* `CONCAT_EMBEDDINGS`
|
||||
* `CUSTOM`
|
||||
* `EMBEDDING_LOOKUP`
|
||||
* `EMBEDDING_LOOKUP_SPARSE`
|
||||
* `HASHTABLE_LOOKUP`
|
||||
* `LSH_PROJECTION`
|
||||
* `LSTM`
|
||||
* `RNN`
|
||||
* `SKIP_GRAM`
|
||||
* `SVDF`
|
||||
|
BIN
tensorflow/lite/g3doc/images/convert/op_fusion.png
Normal file
BIN
tensorflow/lite/g3doc/images/convert/op_fusion.png
Normal file
Binary file not shown.
After Width: | Height: | Size: 39 KiB |
BIN
tensorflow/lite/g3doc/images/convert/op_fusion_banner.jpg
Normal file
BIN
tensorflow/lite/g3doc/images/convert/op_fusion_banner.jpg
Normal file
Binary file not shown.
After Width: | Height: | Size: 88 KiB |
@ -12,7 +12,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"colab": {},
|
||||
@ -49,7 +49,7 @@
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "nDABAblytltI"
|
||||
},
|
||||
},
|
||||
"source": [
|
||||
"\u003ctable class=\"tfo-notebook-buttons\" align=\"left\"\u003e\n",
|
||||
" \u003ctd\u003e\n",
|
||||
@ -93,7 +93,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -116,7 +116,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -131,6 +131,7 @@
|
||||
"\n",
|
||||
"from tensorflow_examples.lite.model_maker.core.data_util.image_dataloader import ImageClassifierDataLoader\n",
|
||||
"from tensorflow_examples.lite.model_maker.core.task import image_classifier\n",
|
||||
"from tensorflow_examples.lite.model_maker.core.task.configs import QuantizationConfig\n",
|
||||
"from tensorflow_examples.lite.model_maker.core.task.model_spec import mobilenet_v2_spec\n",
|
||||
"from tensorflow_examples.lite.model_maker.core.task.model_spec import ImageModelSpec\n",
|
||||
"\n",
|
||||
@ -161,7 +162,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"cellView": "form",
|
||||
"colab": {},
|
||||
@ -221,7 +222,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -245,7 +246,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -268,7 +269,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -294,7 +295,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -370,7 +371,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -398,7 +399,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -421,7 +422,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -445,7 +446,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -478,7 +479,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -501,7 +502,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -526,7 +527,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -549,7 +550,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -609,7 +610,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -644,7 +645,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -724,6 +725,83 @@
|
||||
"In this section, we describe several advanced topics, including switching to a different image classification model, changing the training hyperparameters etc.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "Gc4Jk8TvBQfm"
|
||||
},
|
||||
"source": [
|
||||
"## Post-training quantization on the TensorFLow Lite model\n",
|
||||
"\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "tD8BOYrHBiDt"
|
||||
},
|
||||
"source": [
|
||||
"[Post-training quantization](https://www.tensorflow.org/lite/performance/post_training_quantization) is a conversion technique that can reduce model size and inference latency, while also improving CPU and hardware accelerator latency, with little degradation in model accuracy. Thus, it's widely used to optimize the model.\n"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "iyIo0d5TCzE2"
|
||||
},
|
||||
"source": [
|
||||
"Model Maker supports multiple post-training quantization options. Let's take full integer quantization as an instance. First, define the quantization config to enforce enforce full integer quantization for all ops including the input and output. The input type and output type are `uint8` by default. You may also change them to other types like `int8` by setting `inference_input_type` and `inference_output_type` in config."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "k8hL2mstCxQl"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"config = QuantizationConfig.create_full_integer_quantization(representative_data=test_data, is_integer_only=True)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "K1gzx_rmFMOA"
|
||||
},
|
||||
"source": [
|
||||
"Then we export TensorFlow Lite model with such configuration."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
"id": "WTJzFQnJFMjr"
|
||||
},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"model.export(export_dir='.', tflite_filename='model_quant.tflite', quantization_config=config)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
"colab_type": "text",
|
||||
"id": "Safo0e40wKZW"
|
||||
},
|
||||
"source": [
|
||||
"In Colab, you can download the model named `model_quant.tflite` from the left sidebar, same as the uploading part mentioned above."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "markdown",
|
||||
"metadata": {
|
||||
@ -750,7 +828,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -773,7 +851,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -802,7 +880,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -871,7 +949,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
@ -894,7 +972,7 @@
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 0,
|
||||
"execution_count": null,
|
||||
"metadata": {
|
||||
"colab": {},
|
||||
"colab_type": "code",
|
||||
|
@ -1892,61 +1892,70 @@ void NeonCwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
|
||||
}
|
||||
}
|
||||
|
||||
void NeonCwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input) {
|
||||
const int16x8_t max_dup = vdupq_n_s16(clipping_value);
|
||||
const int16x8_t min_dup = vdupq_n_s16(-clipping_value);
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
int i = 0;
|
||||
for (; i <= n_input - 16; i += 16) {
|
||||
const int index = batch * n_input + i;
|
||||
int16x8_t val_0 = vld1q_s16(input + index);
|
||||
int16x8_t val_1 = vld1q_s16(input + index + 8);
|
||||
val_0 = vminq_s16(val_0, max_dup);
|
||||
val_1 = vminq_s16(val_1, max_dup);
|
||||
val_0 = vmaxq_s16(val_0, min_dup);
|
||||
val_1 = vmaxq_s16(val_1, min_dup);
|
||||
vst1q_s16(input + index, val_0);
|
||||
vst1q_s16(input + index + 8, val_1);
|
||||
}
|
||||
for (; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
if (input[index] > clipping_value) {
|
||||
input[index] = clipping_value;
|
||||
}
|
||||
if (input[index] < -clipping_value) {
|
||||
input[index] = -clipping_value;
|
||||
}
|
||||
}
|
||||
void NeonCwiseClipping(float* vector, const int v_size,
|
||||
const float clipping_value) {
|
||||
const float32x4_t clipping_value_f32x4 = vmovq_n_f32(clipping_value);
|
||||
const float32x4_t neg_clipping_value_f32x4 = vmovq_n_f32(-clipping_value);
|
||||
|
||||
int i = 0;
|
||||
for (; i <= v_size - kFloatValuesPerNeonVector;
|
||||
i += kFloatValuesPerNeonVector) {
|
||||
// Load from memory to vector.
|
||||
float32x4_t v_f32x4 = vld1q_f32(vector + i);
|
||||
// Clip between clipping_value and -clipping_value.
|
||||
v_f32x4 = vminq_f32(clipping_value_f32x4, v_f32x4);
|
||||
v_f32x4 = vmaxq_f32(neg_clipping_value_f32x4, v_f32x4);
|
||||
// Save to output.
|
||||
vst1q_f32(vector + i, v_f32x4);
|
||||
}
|
||||
for (; i < v_size; i++) {
|
||||
vector[i] = std::max(std::min(clipping_value, vector[i]), -clipping_value);
|
||||
}
|
||||
}
|
||||
|
||||
void NeonCwiseClipping(int8_t* input, const int8_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input) {
|
||||
void NeonCwiseClipping(int16_t* vector, const int v_size,
|
||||
const int16_t clipping_value) {
|
||||
const int16x8_t max_dup = vdupq_n_s16(clipping_value);
|
||||
const int16x8_t min_dup = vdupq_n_s16(-clipping_value);
|
||||
|
||||
int i = 0;
|
||||
for (; i <= v_size - kInt16ValuesPerNeonVector * 2;
|
||||
i += kInt16ValuesPerNeonVector * 2) {
|
||||
int16x8_t val_0 = vld1q_s16(vector + i);
|
||||
int16x8_t val_1 = vld1q_s16(vector + i + kInt16ValuesPerNeonVector);
|
||||
val_0 = vminq_s16(val_0, max_dup);
|
||||
val_1 = vminq_s16(val_1, max_dup);
|
||||
val_0 = vmaxq_s16(val_0, min_dup);
|
||||
val_1 = vmaxq_s16(val_1, min_dup);
|
||||
vst1q_s16(vector + i, val_0);
|
||||
vst1q_s16(vector + i + kInt16ValuesPerNeonVector, val_1);
|
||||
}
|
||||
for (; i < v_size; i++) {
|
||||
vector[i] = std::max(std::min(clipping_value, vector[i]),
|
||||
static_cast<int16_t>(-clipping_value));
|
||||
}
|
||||
}
|
||||
|
||||
void NeonCwiseClipping(int8_t* vector, const int v_size,
|
||||
const int8_t clipping_value) {
|
||||
const int8x16_t max_dup = vdupq_n_s8(clipping_value);
|
||||
const int8x16_t min_dup = vdupq_n_s8(-clipping_value);
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
int i = 0;
|
||||
for (; i <= n_input - 32; i += 32) {
|
||||
const int index = batch * n_input + i;
|
||||
int8x16_t val_0 = vld1q_s8(input + index);
|
||||
int8x16_t val_1 = vld1q_s8(input + index + 16);
|
||||
val_0 = vminq_s8(val_0, max_dup);
|
||||
val_1 = vminq_s8(val_1, max_dup);
|
||||
val_0 = vmaxq_s8(val_0, min_dup);
|
||||
val_1 = vmaxq_s8(val_1, min_dup);
|
||||
vst1q_s8(input + index, val_0);
|
||||
vst1q_s8(input + index + 16, val_1);
|
||||
}
|
||||
for (; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
if (input[index] > clipping_value) {
|
||||
input[index] = clipping_value;
|
||||
}
|
||||
if (input[index] < -clipping_value) {
|
||||
input[index] = -clipping_value;
|
||||
}
|
||||
}
|
||||
|
||||
int i = 0;
|
||||
for (; i < v_size - kInt8ValuesPerNeonVector * 2;
|
||||
i += kInt8ValuesPerNeonVector * 2) {
|
||||
int8x16_t val_0 = vld1q_s8(vector + i);
|
||||
int8x16_t val_1 = vld1q_s8(vector + i + kInt8ValuesPerNeonVector);
|
||||
val_0 = vminq_s8(val_0, max_dup);
|
||||
val_1 = vminq_s8(val_1, max_dup);
|
||||
val_0 = vmaxq_s8(val_0, min_dup);
|
||||
val_1 = vmaxq_s8(val_1, min_dup);
|
||||
vst1q_s8(vector + i, val_0);
|
||||
vst1q_s8(vector + i + kInt8ValuesPerNeonVector, val_1);
|
||||
}
|
||||
for (; i < v_size; i++) {
|
||||
vector[i] = std::max(std::min(clipping_value, vector[i]),
|
||||
static_cast<int8_t>(-clipping_value));
|
||||
}
|
||||
}
|
||||
|
||||
@ -2208,34 +2217,6 @@ bool NeonIsZeroVector(const int8_t* vector, int v_size) {
|
||||
return true;
|
||||
}
|
||||
|
||||
void NeonClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result) {
|
||||
// If v_size is not divisible by the vector size, then we need to process the
|
||||
// final few elements sequentially. postamble_start shows the start index
|
||||
// where this should happen.
|
||||
const int postamble_start =
|
||||
RoundDownVectors<kFloatValuesPerNeonVector>(v_size);
|
||||
|
||||
// Replicate abs_limit and -abs_limit in two vectors.
|
||||
const float32x4_t abs_limit_f32x4 = vmovq_n_f32(abs_limit);
|
||||
const float32x4_t neg_abs_limit_f32x4 = vmovq_n_f32(-abs_limit);
|
||||
|
||||
int v = 0;
|
||||
for (; v < postamble_start; v += kFloatValuesPerNeonVector) {
|
||||
// Load from memory to vector.
|
||||
float32x4_t v_f32x4 = vld1q_f32(vector + v);
|
||||
// Clip between abs_limit and -abs_limit.
|
||||
float32x4_t result_f32x4 = vminq_f32(abs_limit_f32x4, v_f32x4);
|
||||
result_f32x4 = vmaxq_f32(neg_abs_limit_f32x4, result_f32x4);
|
||||
// Save to output.
|
||||
vst1q_f32(result + v, result_f32x4);
|
||||
}
|
||||
// Postamble loop.
|
||||
for (; v < v_size; v++) {
|
||||
result[v] = std::max(std::min(abs_limit, vector[v]), -abs_limit);
|
||||
}
|
||||
}
|
||||
|
||||
void NeonVectorScalarMultiply(const int8_t* vector, const int v_size,
|
||||
const float scale, float* result) {
|
||||
// Here the assumption is that each buffer is 4-byte aligned.
|
||||
|
@ -198,14 +198,17 @@ void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
|
||||
NEON_OR_PORTABLE(CwiseAdd, input_1, input_2, n_batch, n_input, output);
|
||||
}
|
||||
|
||||
void CwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input) {
|
||||
NEON_OR_PORTABLE(CwiseClipping, input, clipping_value, n_batch, n_input);
|
||||
void CwiseClipping(float* vector, const int v_size,
|
||||
const float clipping_value) {
|
||||
NEON_OR_PORTABLE(CwiseClipping, vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void CwiseClipping(int8_t* input, const int8_t clipping_value, int32_t n_batch,
|
||||
int32_t n_input) {
|
||||
NEON_OR_PORTABLE(CwiseClipping, input, clipping_value, n_batch, n_input);
|
||||
void CwiseClipping(int16_t* vector, const int v_size,
|
||||
const int16_t clipping_value) {
|
||||
NEON_OR_PORTABLE(CwiseClipping, vector, v_size, clipping_value);
|
||||
}
|
||||
void CwiseClipping(int8_t* vector, const int v_size,
|
||||
const int8_t clipping_value) {
|
||||
NEON_OR_PORTABLE(CwiseClipping, vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
|
||||
@ -255,10 +258,6 @@ void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result) {
|
||||
NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
|
||||
}
|
||||
void ClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result) {
|
||||
NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
|
||||
}
|
||||
|
||||
void SymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float* min_value,
|
||||
|
@ -83,11 +83,12 @@ void NeonCwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
void NeonCwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
|
||||
int n_input, int16_t* output);
|
||||
|
||||
void NeonCwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input);
|
||||
|
||||
void NeonCwiseClipping(int8_t* input, const int8_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input);
|
||||
void NeonCwiseClipping(float* vector, const int v_size,
|
||||
const float clipping_value);
|
||||
void NeonCwiseClipping(int16_t* vector, const int v_size,
|
||||
const int16_t clipping_value);
|
||||
void NeonCwiseClipping(int8_t* vector, const int v_size,
|
||||
const int8_t clipping_value);
|
||||
|
||||
void NeonMatrixBatchVectorMultiplyAccumulate(
|
||||
const int8_t* input, const int32_t* bias,
|
||||
@ -133,10 +134,6 @@ void NeonSub1Vector(const float* vector, int v_size, float* result);
|
||||
|
||||
void NeonSub1Vector(const int16_t* vector, int v_size, int16_t* result);
|
||||
|
||||
// Clip elements of a vector using a abs_limit value.
|
||||
void NeonClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result);
|
||||
|
||||
// Multiply all elements of vector with a scalar.
|
||||
void NeonVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result);
|
||||
|
@ -206,14 +206,19 @@ void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
|
||||
PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
|
||||
}
|
||||
|
||||
void CwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input) {
|
||||
PortableCwiseClipping(input, clipping_value, n_batch, n_input);
|
||||
void CwiseClipping(float* vector, const int v_size,
|
||||
const float clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void CwiseClipping(int8_t* input, const int8_t clipping_value, int32_t n_batch,
|
||||
int32_t n_input) {
|
||||
PortableCwiseClipping(input, clipping_value, n_batch, n_input);
|
||||
void CwiseClipping(int16_t* vector, const int v_size,
|
||||
const int16_t clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void CwiseClipping(int8_t* vector, const int v_size,
|
||||
const int8_t clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void BatchVectorBatchVectorDotProduct(const int16_t* vector1,
|
||||
@ -263,10 +268,6 @@ void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result) {
|
||||
NEON_OR_PORTABLE(VectorScalarMultiply, vector, v_size, scale, result);
|
||||
}
|
||||
void ClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result) {
|
||||
NEON_OR_PORTABLE(ClipVector, vector, v_size, abs_limit, result);
|
||||
}
|
||||
|
||||
void SymmetricQuantizeFloats(const float* values, const int size,
|
||||
int8_t* quantized_values, float* min_value,
|
||||
|
@ -651,36 +651,6 @@ void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
|
||||
}
|
||||
}
|
||||
|
||||
void PortableCwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input) {
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
if (input[index] > clipping_value) {
|
||||
input[index] = clipping_value;
|
||||
}
|
||||
if (input[index] < -clipping_value) {
|
||||
input[index] = -clipping_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
void PortableCwiseClipping(int8_t* input, const int8_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input) {
|
||||
for (int batch = 0; batch < n_batch; ++batch) {
|
||||
for (int i = 0; i < n_input; ++i) {
|
||||
const int index = batch * n_input + i;
|
||||
if (input[index] > clipping_value) {
|
||||
input[index] = clipping_value;
|
||||
}
|
||||
if (input[index] < -clipping_value) {
|
||||
input[index] = -clipping_value;
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
float PortableVectorVectorDotProduct(const float* vector1, const float* vector2,
|
||||
int v_size) {
|
||||
float result = 0.0;
|
||||
@ -757,13 +727,6 @@ void PortableVectorScalarMultiply(const int8_t* vector, const int v_size,
|
||||
}
|
||||
}
|
||||
|
||||
void PortableClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result) {
|
||||
for (int v = 0; v < v_size; v++) {
|
||||
result[v] = std::max(std::min(abs_limit, vector[v]), -abs_limit);
|
||||
}
|
||||
}
|
||||
|
||||
void PortableReductionSumVector(const float* input_vector, float* output_vector,
|
||||
int output_size, int reduction_size) {
|
||||
const float* input_vector_ptr = input_vector;
|
||||
|
@ -230,14 +230,19 @@ void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
|
||||
PortableCwiseAdd(input_1, input_2, n_batch, n_input, output);
|
||||
}
|
||||
|
||||
void CwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input) {
|
||||
PortableCwiseClipping(input, clipping_value, n_batch, n_input);
|
||||
void CwiseClipping(float* vector, const int v_size,
|
||||
const float clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void CwiseClipping(int8_t* input, const int8_t clipping_value, int32_t n_batch,
|
||||
int32_t n_input) {
|
||||
PortableCwiseClipping(input, clipping_value, n_batch, n_input);
|
||||
void CwiseClipping(int16_t* vector, const int v_size,
|
||||
const int16_t clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void CwiseClipping(int8_t* vector, const int v_size,
|
||||
const int8_t clipping_value) {
|
||||
PortableCwiseClipping(vector, v_size, clipping_value);
|
||||
}
|
||||
|
||||
void VectorBatchVectorCwiseProductAccumulate(const int16_t* vector, int v_size,
|
||||
@ -279,11 +284,6 @@ void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
PortableVectorScalarMultiply(vector, v_size, scale, result);
|
||||
}
|
||||
|
||||
void ClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result) {
|
||||
PortableClipVector(vector, v_size, abs_limit, result);
|
||||
}
|
||||
|
||||
void ReductionSumVector(const float* input_vector, float* output_vector,
|
||||
int output_size, int reduction_size) {
|
||||
PortableReductionSumVector(input_vector, output_vector, output_size,
|
||||
|
@ -15,6 +15,7 @@ limitations under the License.
|
||||
#ifndef TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_IMPL_H_
|
||||
#define TENSORFLOW_LITE_KERNELS_INTERNAL_REFERENCE_PORTABLE_TENSOR_UTILS_IMPL_H_
|
||||
|
||||
#include <algorithm>
|
||||
#include <cstdint>
|
||||
|
||||
// TODO(ghodrat): Remove this header file and the dependency to internal data
|
||||
@ -33,9 +34,6 @@ class CpuBackendContext;
|
||||
|
||||
namespace tensor_utils {
|
||||
|
||||
// Limit a float input f between +abs_limit and -abs_limit.
|
||||
float PortableClip(float f, float abs_limit);
|
||||
|
||||
template <typename T>
|
||||
bool PortableIsZeroVector(const T* vector, int v_size) {
|
||||
for (int i = 0; i < v_size; ++i) {
|
||||
@ -178,11 +176,14 @@ void PortableCwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
void PortableCwiseAdd(const int16_t* input_1, const int16_t* input_2,
|
||||
int n_batch, int n_input, int16_t* output);
|
||||
|
||||
void PortableCwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input);
|
||||
|
||||
void PortableCwiseClipping(int8_t* input, const int8_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input);
|
||||
template <typename T>
|
||||
void PortableCwiseClipping(T* vector, const int v_size,
|
||||
const T clipping_value) {
|
||||
for (int i = 0; i < v_size; i++) {
|
||||
vector[i] = std::max(std::min(clipping_value, vector[i]),
|
||||
static_cast<T>(-clipping_value));
|
||||
}
|
||||
}
|
||||
|
||||
// Batch vector initialization with another vector.
|
||||
void PortableVectorBatchVectorAssign(const float* vector, int v_size,
|
||||
@ -201,10 +202,6 @@ void PortableSub1Vector(const int16_t* vector, int v_size, int16_t* result);
|
||||
void PortableVectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result);
|
||||
|
||||
// Clip elements of a vector using a abs_limit value.
|
||||
void PortableClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result);
|
||||
|
||||
// Reduce-sum on a float input vector:
|
||||
// input_vector: float pointer to input vector.
|
||||
// output_vector: float pointer to vector.
|
||||
|
@ -406,23 +406,16 @@ void CwiseMul(const int16_t* input_1, const int16_t* input_2,
|
||||
void CwiseAdd(const int16_t* input_1, const int16_t* input_2, int n_batch,
|
||||
int n_input, int16_t* output);
|
||||
|
||||
// Element-wise in-place clipping of a quantized vector.
|
||||
// Parameters:
|
||||
// - input: batch vector of size n_batch * n_input; 16 bit.
|
||||
// Element-wise in-place clipping of a vector. Overloaded for float, int16_t,
|
||||
// int8_t. Parameters:
|
||||
// - vector: vector of size v_size.
|
||||
// - v_size: the size of the vector.
|
||||
// - clipping_value: the value used for clipping.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
void CwiseClipping(int16_t* input, const int16_t clipping_value,
|
||||
int32_t n_batch, int32_t n_input);
|
||||
|
||||
// Element-wise in-place clipping of a quantized vector.
|
||||
// Parameters:
|
||||
// - input: batch vector of size n_batch * n_input; 8 bit.
|
||||
// - clipping_value: the value used for clipping.
|
||||
// - n_batch: the number of batches.
|
||||
// - n_input: the size for input and output.
|
||||
void CwiseClipping(int8_t* input, const int8_t clipping_value, int32_t n_batch,
|
||||
int32_t n_input);
|
||||
void CwiseClipping(float* vector, const int v_size, const float clipping_value);
|
||||
void CwiseClipping(int16_t* vector, const int v_size,
|
||||
const int16_t clipping_value);
|
||||
void CwiseClipping(int8_t* vector, const int v_size,
|
||||
const int8_t clipping_value);
|
||||
|
||||
// Cwise product of two vectors.
|
||||
template <typename T>
|
||||
@ -611,10 +604,6 @@ void Sub1Vector(const int16_t* vector, int v_size, int16_t* result);
|
||||
void VectorScalarMultiply(const int8_t* vector, int v_size, float scale,
|
||||
float* result);
|
||||
|
||||
// Clip elements of a vector using a abs_limit value.
|
||||
void ClipVector(const float* vector, int v_size, float abs_limit,
|
||||
float* result);
|
||||
|
||||
// Reduce-sum on a float input vector:
|
||||
// input_vector: float pointer to input vector.
|
||||
// output_vector: float pointer to vector.
|
||||
|
Some files were not shown because too many files have changed in this diff Show More
Loading…
Reference in New Issue
Block a user