Merge branch 'master' into macos_arm64_cmake

commit 8817e44105

Changed files:

tensorflow/c/eager/
    BUILD, c_api_test_util.h, gradient_checker.cc, gradient_checker_test.cc,
    gradients_test.cc, gradients_util.cc, gradients_util.h,
    mnist_gradients_test.cc, mnist_gradients_testutil.cc,
    mnist_gradients_testutil.h, unified_api_test.cc, unified_api_testutil.cc,
    unified_api_testutil.h
tensorflow/c/experimental/gradients/
tensorflow/compiler/mlir/
    hlo, lite/ir, lite/tests, lite/transforms, python,
    tensorflow/tests/tf_saved_model, tensorflow/transforms, tensorflow/utils, xla
tensorflow/compiler/
    tf2tensorrt, tf2xla, xla
tensorflow/core/
    BUILD
    api_def/base_api/
        api_def_TensorForestCreateTreeVariable.pbtxt,
        api_def_TensorForestTreeDeserialize.pbtxt,
        api_def_TensorForestTreeIsInitializedOp.pbtxt,
        api_def_TensorForestTreePredict.pbtxt,
        api_def_TensorForestTreeResourceHandleOp.pbtxt,
        api_def_TensorForestTreeSerialize.pbtxt,
        api_def_TensorForestTreeSize.pbtxt
    api_def/java_api/
    framework/
    kernels/
        BUILD
        hexagon/
            BUILD, graph_transfer_utils.cc, graph_transfer_utils.h,
            graph_transferer.cc, graph_transferer.h, graph_transferer_test.cc,
            hexagon_control_wrapper.cc, hexagon_control_wrapper.h,
            hexagon_graph_execution_test.cc, hexagon_ops_definitions.cc,
            hexagon_ops_definitions.h, hexagon_remote_fused_graph_executor_build.cc,
            hexagon_remote_fused_graph_executor_build_test.cc,
            hexagon_rewriter_transform.cc
tensorflow/c/eager/BUILD

@@ -3,7 +3,6 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
    "//tensorflow:tensorflow.bzl",
    "if_libtpu",
    "tf_cc_test",
    "tf_copts",
    "tf_cuda_cc_test",

@@ -320,75 +319,6 @@ tf_cuda_cc_test(
    ],
)

-cc_library(
-    name = "gradients_util",
-    srcs = [
-        "gradients_util.cc",
-    ],
-    hdrs = [
-        "gradients_util.h",
-    ],
-    visibility = [
-        "//tensorflow:internal",
-    ],
-    deps = [
-        ":abstract_context",
-        ":abstract_operation",
-        ":abstract_tensor_handle",
-        ":c_api",
-        ":c_api_experimental",
-        ":c_api_unified_internal",
-        ":gradients_internal",
-        ":tape",
-        "@com_google_absl//absl/container:flat_hash_map",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-        "//tensorflow/c:c_api",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/experimental/ops:array_ops",
-        "//tensorflow/c/experimental/ops:math_ops",
-        "//tensorflow/c/experimental/ops:nn_ops",
-        "//tensorflow/cc/profiler",
-        "//tensorflow/core:framework",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core/lib/llvm_rtti",
-    ] + if_libtpu(
-        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
-        if_true = [],
-    ),
-)
-
-cc_library(
-    name = "mnist_gradients_testutil",
-    srcs = [
-        "mnist_gradients_testutil.cc",
-    ],
-    hdrs = [
-        "mnist_gradients_testutil.h",
-    ],
-    visibility = [
-        "//tensorflow:internal",
-    ],
-    deps = [
-        ":abstract_tensor_handle",
-        ":c_api_experimental",
-        ":c_api_unified_internal",
-        ":gradients_internal",
-        ":gradients_util",
-        ":tape",
-        "//tensorflow/c/experimental/gradients/tape:tape_context",
-        "//tensorflow/c/experimental/ops:array_ops",
-        "//tensorflow/c/experimental/ops:math_ops",
-        "//tensorflow/c/experimental/ops:nn_ops",
-        "//tensorflow/core/lib/llvm_rtti",
-        "//tensorflow/core/platform:status",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/types:span",
-    ],
-)
-
cc_library(
    name = "gradient_checker",
    testonly = 1,

@@ -436,46 +366,6 @@ tf_cuda_cc_test(
    ],
)

-tf_cuda_cc_test(
-    name = "mnist_gradients_test",
-    size = "small",
-    srcs = [
-        "mnist_gradients_test.cc",
-    ],
-    args = ["--heap_check=local"],
-    linkstatic = tf_kernel_tests_linkstatic(),
-    tags = tf_cuda_tests_tags() + [
-        "no_cuda_asan",  # b/173825513
-    ],
-    deps = [
-        ":abstract_tensor_handle",
-        ":c_api_experimental",
-        ":c_api_unified_internal",
-        ":gradients_internal",
-        ":gradients_util",
-        ":mnist_gradients_testutil",
-        "//tensorflow/c:c_api",
-        "//tensorflow/c:c_test_util",
-        "//tensorflow/c:tf_status_helper",
-        "//tensorflow/c/experimental/gradients:math_grad",
-        "//tensorflow/c/experimental/gradients:nn_grad",
-        "//tensorflow/c/experimental/ops:array_ops",
-        "//tensorflow/c/experimental/ops:math_ops",
-        "//tensorflow/c/experimental/ops:nn_ops",
-        "//tensorflow/cc/profiler",
-        "//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
-        "//tensorflow/core:lib",
-        "//tensorflow/core:protos_all_cc",
-        "//tensorflow/core:test",
-        "//tensorflow/core:test_main",
-        "//tensorflow/core/lib/llvm_rtti",
-        "//tensorflow/core/platform:tensor_float_32_utils",
-        "@com_google_absl//absl/container:flat_hash_set",
-        "@com_google_absl//absl/strings",
-        "@com_google_absl//absl/types:span",
-    ],
-)
-
cc_library(
    name = "abstract_tensor_handle",
    srcs = ["abstract_tensor_handle.cc"],

@@ -1157,8 +1047,6 @@ filegroup(
        "gradient_checker.cc",
        "gradient_checker.h",
        "gradients.cc",  # Uses RTTI.
-       "gradients_util.cc",
-       "gradients_util.h",
        "tracing_utils.h",
        "tracing_utils.cc",
        "*test*",
tensorflow/c/eager/c_api_test_util.h

@@ -16,6 +16,9 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_

#include "tensorflow/c/eager/c_api.h"
+#include "tensorflow/c/eager/c_api_experimental.h"
+#include "tensorflow/c/tf_datatype.h"
+#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"

@@ -53,6 +56,27 @@ TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
                                              int64_t dims[], int num_dims);

+// Return a tensor handle with given type, values and dimensions.
+template <class T, TF_DataType datatype>
+TFE_TensorHandle* TestTensorHandleWithDims(TFE_Context* ctx, const T* data,
+                                           const int64_t* dims, int num_dims) {
+  TF_Status* status = TF_NewStatus();
+  TF_Tensor* t = TFE_AllocateHostTensor(ctx, datatype, dims, num_dims, status);
+  memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
+  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
+  CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
+  TF_DeleteTensor(t);
+  TF_DeleteStatus(status);
+  return th;
+}
+
+// Return a scalar tensor handle with given values.
+template <class T, TF_DataType datatype>
+TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, const T value) {
+  T data[] = {value};
+  return TestTensorHandleWithDims<T, datatype>(ctx, data, nullptr, 0);
+}
+
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
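[Note — not part of the diff] A minimal usage sketch of the new templated helpers, assuming a live TFE_Context* ctx; the values are illustrative only:

    float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
    int64_t dims[] = {2, 2};
    // 2x2 float tensor handle via the type-parameterized template:
    TFE_TensorHandle* m =
        TestTensorHandleWithDims<float, TF_FLOAT>(ctx, vals, dims, 2);
    // int32 scalar via the scalar wrapper:
    TFE_TensorHandle* s = TestScalarTensorHandle<int32_t, TF_INT32>(ctx, 7);
    // ... use the handles ...
    TFE_DeleteTensorHandle(m);
    TFE_DeleteTensorHandle(s);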
tensorflow/c/eager/gradient_checker.cc

@@ -29,8 +29,9 @@ using namespace std;
// ================== Helper functions =================

// Fills data with values [start,end) with given step size.
-void Range(vector<int>* data, int start, int end, int step = 1) {
-  for (int i = start; i < end; i += step) {
+void Range(vector<int32_t>* data, int32_t start, int32_t end,
+           int32_t step = 1) {
+  for (int32_t i = start; i < end; i += step) {
    (*data)[i] = i;
  }
}
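[Note — not part of the diff] With the default step, Range writes value i at index i for i in [start, end), e.g.:

    vector<int32_t> v(4);
    Range(&v, 0, 4);  // v == {0, 1, 2, 3}

The checker uses exactly this form below to build the reduction-axes tensor [0, ..., num_dims_out-1].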
@@ -72,12 +73,12 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
  // Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
  AbstractTensorHandlePtr sum_dims;
  {
-    vector<int> vals(num_dims_out);
+    vector<int32_t> vals(num_dims_out);
    int64_t vals_shape[] = {num_dims_out};
    Range(&vals, 0, num_dims_out);
    AbstractTensorHandle* sum_dims_raw = nullptr;
-    TF_RETURN_IF_ERROR(TestTensorHandleWithDimsInt(ctx, vals.data(), vals_shape,
-                                                   1, &sum_dims_raw));
+    TF_RETURN_IF_ERROR(TestTensorHandleWithDims<int32_t, TF_INT32>(
+        ctx, vals.data(), vals_shape, 1, &sum_dims_raw));
    sum_dims.reset(sum_dims_raw);
  }

@@ -130,8 +131,8 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
  AbstractTensorHandlePtr two_eps;
  {
    AbstractTensorHandle* two_eps_raw = nullptr;
-    TF_RETURN_IF_ERROR(
-        TestScalarTensorHandle(ctx, 2 * epsilon, &two_eps_raw));
+    TF_RETURN_IF_ERROR(TestScalarTensorHandle<float, TF_FLOAT>(
+        ctx, 2 * epsilon, &two_eps_raw));
    two_eps.reset(two_eps_raw);
  }

@@ -142,7 +143,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
  AbstractTensorHandlePtr thetaPlus;
  {
    AbstractTensorHandle* thetaPlus_raw = nullptr;
-    TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
+    TF_RETURN_IF_ERROR(TestTensorHandleWithDims<float, TF_FLOAT>(
        ctx, thetaPlus_data.data(), theta_dims.data(), num_dims,
        &thetaPlus_raw));
    thetaPlus.reset(thetaPlus_raw);

@@ -155,7 +156,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
  AbstractTensorHandlePtr thetaMinus;
  {
    AbstractTensorHandle* thetaMinus_raw = nullptr;
-    TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
+    TF_RETURN_IF_ERROR(TestTensorHandleWithDims<float, TF_FLOAT>(
        ctx, thetaMinus_data.data(), theta_dims.data(), num_dims,
        &thetaMinus_raw));
    thetaMinus.reset(thetaMinus_raw);

@@ -194,7 +195,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
  }

  // Populate *numerical_grad with the data from dtheta_approx.
-  TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
+  TF_RETURN_IF_ERROR(TestTensorHandleWithDims<float, TF_FLOAT>(
      ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
  TF_DeleteTensor(theta_tensor);
  return Status::OK();
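[Note — not part of the diff] The thetaPlus/thetaMinus handles built above feed the standard central-difference estimate — for each component i of theta,

    numerical_grad[i] ≈ (f(theta + eps*e_i) - f(theta - eps*e_i)) / (2*eps)

which is why the scalar two_eps = 2 * epsilon is constructed once up front.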
tensorflow/c/eager/gradient_checker_test.cc

@@ -117,8 +117,8 @@ TEST_P(GradientCheckerTest, TestMatMul) {
  AbstractTensorHandlePtr A;
  {
    AbstractTensorHandle* A_raw;
-    Status s =
-        TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
+    Status s = TestTensorHandleWithDims<float, TF_FLOAT>(ctx_.get(), A_vals,
+                                                         A_dims, 2, &A_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    A.reset(A_raw);
  }

@@ -127,8 +127,8 @@ TEST_P(GradientCheckerTest, TestMatMul) {
  AbstractTensorHandlePtr B;
  {
    AbstractTensorHandle* B_raw;
-    Status s =
-        TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
+    Status s = TestTensorHandleWithDims<float, TF_FLOAT>(ctx_.get(), B_vals,
+                                                         B_dims, 2, &B_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    B.reset(B_raw);
  }

@@ -143,7 +143,8 @@ TEST_P(GradientCheckerTest, TestMul) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
+    Status s =
+        TestScalarTensorHandle<float, TF_FLOAT>(ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }

@@ -151,7 +152,8 @@ TEST_P(GradientCheckerTest, TestMul) {
  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw);
+    Status s =
+        TestScalarTensorHandle<float, TF_FLOAT>(ctx_.get(), 7.0f, &y_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    y.reset(y_raw);
  }
tensorflow/c/eager/gradients_test.cc

@@ -58,100 +58,10 @@ class CppGradients
};

Status RegisterGradients(GradientRegistry* registry) {
  // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
  // AddV2Registerer.
  TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
  TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics"));
  return Status::OK();
}

-// Computes
-// ignored, y = IdentityN(inputs[0], inputs[1])
-// return grad(y, {inputs[0], inputs[1]})
-// This should return [nullptr, 1].
-Status IdentityNGradModel(AbstractContext* ctx,
-                          absl::Span<AbstractTensorHandle* const> inputs,
-                          absl::Span<AbstractTensorHandle*> outputs) {
-  GradientRegistry registry;
-  TF_RETURN_IF_ERROR(RegisterGradients(&registry));
-
-  auto tape = std::make_unique<Tape>(/*persistent=*/false);
-  tape->Watch(inputs[0]);
-  tape->Watch(inputs[1]);
-
-  vector<AbstractTensorHandle*> identity_n_outputs(2);
-  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
-  TF_RETURN_IF_ERROR(ops::IdentityN(
-      tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
-  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
-                                           /*targets=*/{identity_n_outputs[1]},
-                                           /*sources=*/{inputs[0], inputs[1]},
-                                           /*output_gradients=*/{}, outputs));
-  for (auto identity_n_output : identity_n_outputs) {
-    identity_n_output->Unref();
-  }
-  return Status::OK();
-}
-
-TEST_P(CppGradients, TestIdentityNGrad) {
-  // Pseudo-code:
-  //
-  // tape.watch(x1)
-  // tape.watch(x2)
-  // unused, y = IdentityN([x1, x2])
-  // outputs = tape.gradient(y, [x1, x2])
-  // Expected: [nullptr, 1]
-  //
-  // This test is interesting because the current implementation of GradientTape
-  // would return [0, 1] whereas we use build_default_zeros_grads=false here
-  // so we get back [nullptr, 1].
-  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
-      TF_NewStatus(), TF_DeleteStatus);
-  AbstractContextPtr ctx;
-  {
-    AbstractContext* ctx_raw = nullptr;
-    Status s =
-        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    ctx.reset(ctx_raw);
-  }
-
-  AbstractTensorHandlePtr x1;
-  {
-    AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    x1.reset(x_raw);
-  }
-  AbstractTensorHandlePtr x2;
-  {
-    AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
-    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-    x2.reset(x_raw);
-  }
-
-  GradientRegistry registry;
-  Status s = RegisterGradients(&registry);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  std::vector<AbstractTensorHandle*> outputs(2);
-  s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
-               absl::MakeSpan(outputs),
-               /*use_function=*/!std::get<2>(GetParam()));
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-
-  EXPECT_EQ(outputs[0], nullptr);
-  TF_Tensor* result_tensor;
-  s = GetValue(outputs[1], &result_tensor);
-  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
-  auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
-  EXPECT_EQ(*result_value, 1.0);
-  outputs[1]->Unref();
-  TF_DeleteTensor(result_tensor);
-  result_tensor = nullptr;
-}
-
TEST_P(CppGradients, TestSetAttrString) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

@@ -167,7 +77,7 @@ TEST_P(CppGradients, TestSetAttrString) {
  AbstractTensorHandlePtr t;
  {
    AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
+    Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    t.reset(x_raw);
  }

@@ -232,7 +142,7 @@ TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
-    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
+    Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }
tensorflow/c/eager/gradients_util.cc (deleted by this commit)

@@ -1,320 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients_util.h"

#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"

namespace tensorflow {
namespace gradients {

using namespace std;

Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
                                TFE_TensorHandle** result) {
  float data[] = {value};
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
  *result = th;
  TF_DeleteTensor(t);
  return StatusFromTF_Status(status.get());
}

Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
                                       int64_t dims[], int num_dims,
                                       TFE_TensorHandle** result) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
  *result = th;
  TF_DeleteTensor(t);
  return StatusFromTF_Status(status.get());
}

Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
                                     int64_t dims[], int num_dims,
                                     TFE_TensorHandle** result) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_Tensor* t =
      TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
  memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
  TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
  *result = th;
  TF_DeleteTensor(t);
  return StatusFromTF_Status(status.get());
}

// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
                          AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager;
  TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return StatusFromTF_Status(status.get());
}

// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                 int64_t dims[], int num_dims,
                                 AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager;
  TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
                                                     num_dims, &input_eager));
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return StatusFromTF_Status(status.get());
}

// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
                               int num_dims, AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager;
  TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
                                                   num_dims, &input_eager));
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return StatusFromTF_Status(status.get());
}

Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_TensorHandle* result_t =
      TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  *result_tensor = TFE_TensorHandleResolve(result_t, status.get());
  return StatusFromTF_Status(status.get());
}

AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
  if (s.ok()) {
    A.reset(a_raw);
  }
  return A;
}

AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims) {
  AbstractTensorHandlePtr A;
  AbstractTensorHandle* a_raw = nullptr;
  Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
  if (s.ok()) {
    A.reset(a_raw);
  }
  return A;
}

AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
                                                  float val) {
  AbstractTensorHandlePtr y;
  AbstractTensorHandle* y_raw = nullptr;
  Status s = ScalarTensorHandle(ctx, val, &y_raw);
  if (s.ok()) {
    y.reset(y_raw);
  }
  return y;
}

Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
                     vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate) {
  /* Update weights one by one using gradient update rule:
   *
   *    w -= lr*grad[w]
   *
   *  NOTE: assuming learning rate is positive
   */

  int num_grads = grads.size();
  vector<AbstractTensorHandle*> temp_outputs(1);
  std::string update_str;

  // Negate learning rate for gradient descent
  TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
                              absl::MakeSpan(temp_outputs),
                              "neg_lr"));  // Compute -lr
  learning_rate = temp_outputs[0];

  for (int i = 0; i < num_grads; i++) {
    // Compute dW = -lr * grad(w[i])
    update_str = "update_mul_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    AbstractTensorHandle* dW = temp_outputs[0];

    // Compute temp = weights[i] + dW
    update_str = "update_add_" + std::to_string(i);
    TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
                                absl::MakeSpan(temp_outputs),
                                update_str.c_str()));

    // Update the weights
    weights[i] = temp_outputs[0];
  }

  return Status::OK();
}
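[Note — not part of the diff] UpdateWeights negates the learning rate once, then applies w <- w + (-lr) * grad[w] for each weight. As a worked example with illustrative numbers: lr = 0.1, w = 1.0, grad[w] = 2.0 gives an updated weight of 1.0 - 0.1 * 2.0 = 0.8.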

AbstractContext* BuildFunction(const char* fn_name) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
  return unwrap(graph_ctx);
}

Status CreateParamsForInputs(AbstractContext* ctx,
                             absl::Span<AbstractTensorHandle* const> inputs,
                             vector<AbstractTensorHandle*>* params) {
  tracing::TracingTensorHandle* handle = nullptr;
  for (auto input : inputs) {
    PartialTensorShape shape;
    TF_RETURN_IF_ERROR(input->Shape(&shape));
    TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
        input->DataType(), shape, &handle));
    params->emplace_back(handle);
  }
  return Status::OK();
}

Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry) {
  if (use_function) {
    const char* fn_name = "test_fn";
    std::unique_ptr<AbstractFunction> scoped_func;
    // Returning null tensors from a tf.function is not supported, so we keep
    // track of which indices in the model's outputs are nullptr in this set.
    // The FunctionDef only outputs the non-null tensors. We later pad the
    // function op outputs to have nullptrs at the `null_indices`.
    absl::flat_hash_set<int> null_indices;
    {
      AbstractContextPtr func_ctx(BuildFunction(fn_name));
      vector<AbstractTensorHandle*> func_inputs;
      func_inputs.reserve(inputs.size());
      TF_RETURN_IF_ERROR(
          CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
      vector<AbstractTensorHandle*> model_outputs;
      model_outputs.resize(outputs.size());
      TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
                               absl::MakeSpan(model_outputs), registry));
      for (auto func_input : func_inputs) {
        func_input->Unref();
      }
      AbstractFunction* func = nullptr;
      OutputList output_list;
      output_list.expected_num_outputs = 0;
      output_list.outputs.reserve(outputs.size());
      for (int i = 0; i < model_outputs.size(); i++) {
        if (model_outputs[i]) {
          output_list.outputs.emplace_back(model_outputs[i]);
          output_list.expected_num_outputs += 1;
        } else {
          null_indices.insert(i);
        }
      }
      TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
                             ->Finalize(&output_list, &func));
      scoped_func.reset(func);
      for (auto output : output_list.outputs) {
        output->Unref();
      }
      TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
    }

    AbstractOperationPtr fn_op(ctx->CreateOperation());
    TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
    for (auto input : inputs) {
      TF_RETURN_IF_ERROR(fn_op->AddInput(input));
    }
    int retvals = outputs.size() - null_indices.size();
    vector<AbstractTensorHandle*> fn_outputs(retvals);
    TF_RETURN_IF_ERROR(fn_op->Execute(
        absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
        &retvals));
    int skipped_indices = 0;
    for (int i = 0; i < outputs.size(); i++) {
      if (!null_indices.contains(i)) {
        outputs[i] = fn_outputs[i - skipped_indices];
      } else {
        skipped_indices += 1;
      }
    }
    TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
    return Status::OK();
  } else {
    return model(ctx, inputs, outputs, registry);
  }
}
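[Note — not part of the diff] A concrete trace of the nullptr padding above, assuming a hypothetical model with three outputs whose first gradient is null: model_outputs = {nullptr, t1, t2} yields null_indices = {0} and a FunctionDef with two outputs; after Execute, fn_outputs = {u1, u2} is padded back out to outputs = {nullptr, u1, u2}.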

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_ContextOptions* opts = TFE_NewContextOptions();
  TFE_ContextOptionsSetTfrt(opts, use_tfrt);
  *ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_DeleteContextOptions(opts);
  return Status::OK();
}

}  // namespace gradients
}  // namespace tensorflow
tensorflow/c/eager/gradients_util.h (deleted by this commit)

@@ -1,88 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
namespace gradients {

// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
                          AbstractTensorHandle** tensor);

// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
                                 int64_t dims[], int num_dims,
                                 AbstractTensorHandle** tensor);

// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
                               int num_dims, AbstractTensorHandle** tensor);

// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);

// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
                                                 float vals[], int64_t dims[],
                                                 int num_dims);

// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
                                               int64_t dims[], int num_dims);

// Util function that wraps an AbstractTensorHandle* with given data.
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
                                                  float val);

// Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx,
                     std::vector<AbstractTensorHandle*>& grads,
                     std::vector<AbstractTensorHandle*>& weights,
                     AbstractTensorHandle* learning_rate);

using Model = std::function<Status(
    AbstractContext*, absl::Span<AbstractTensorHandle* const>,
    absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
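[Note — not part of the diff] Any callable matching this signature can be handed to RunModel below. A minimal conforming model — hypothetical, for illustration only; the name and body are not from this commit:

    Status ForwardFirstInputModel(AbstractContext* ctx,
                                  absl::Span<AbstractTensorHandle* const> inputs,
                                  absl::Span<AbstractTensorHandle*> outputs,
                                  const GradientRegistry& registry) {
      inputs[0]->Ref();  // callers Unref() outputs, so hand out a new reference
      outputs[0] = inputs[0];
      return Status::OK();
    }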

// Runs given model in either graph or eager mode depending on the value of
// use_function.
Status RunModel(Model model, AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs, bool use_function,
                const GradientRegistry& registry);

// Builds an immediate-execution context and returns it inside *ctx.
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);

}  // namespace gradients
}  // namespace tensorflow
tensorflow/c/eager/mnist_gradients_test.cc (deleted by this commit)

@@ -1,602 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
    http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    CHECK_EQ(errors::OK, s.code()) << s.error_message();

    // Computing numerical gradients with TensorFloat-32 is numerically
    // unstable. Some forward pass tests also fail with TensorFloat-32 due to
    // low tolerances.
    enable_tensor_float_32_execution(false);
  }
};

Status RegisterGradients(GradientRegistry* registry) {
  TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
  TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
  TF_RETURN_IF_ERROR(
      registry->Register("SparseSoftmaxCrossEntropyWithLogits",
                         SparseSoftmaxCrossEntropyWithLogitsRegisterer));
  return Status::OK();
}

TEST_P(CppGradients, TestMatMulGrad) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
  int64_t B_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
  AbstractTensorHandlePtr B =
      GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);

  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   * tape.watch(A)
   * tape.watch(B)
   * Y = AB
   * outputs = tape.gradient(Y, [A, B])
   */

  std::vector<AbstractTensorHandle*> outputs(2);
  s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
               absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
  }

  TF_Tensor* dB_tensor;
  s = GetValue(outputs[1], &dB_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dB_tensor),
         TF_TensorByteSize(dB_tensor));

  float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(dA_tensor);
  TF_DeleteTensor(dB_tensor);
}
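[Note — not part of the diff] The expected values follow from the matmul gradient with an upstream gradient of ones: for Y = A·B, dA = dY·Bᵀ and dB = Aᵀ·dY. With B = [[.5, -1], [1, 1]], each row of dA = ones·Bᵀ is (.5 - 1, 1 + 1) = (-.5, 2), matching expected_dA; with A = [[1, 2], [3, 4]], dB = Aᵀ·ones has rows (1 + 3, 1 + 3) = (4, 4) and (2 + 4, 2 + 4) = (6, 6), matching expected_dB.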

TEST_P(CppGradients, TestMNISTForward) {
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t dims_y[] = {2};
  num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, dims, num_dims);

  GradientRegistry registry;

  // Run the Forward Pass
  std::vector<AbstractTensorHandle*> outputs(2);
  Status s =
      RunModel(MNISTForwardModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

  float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  TF_Tensor* loss_vals_tensor;
  s = GetValue(outputs[1], &loss_vals_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
         TF_TensorByteSize(loss_vals_tensor));
  float expected_losses[2] = {9.6f, 27.2f};
  for (int j = 0; j < 2; j++) {
    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(scores_tensor);
  TF_DeleteTensor(loss_vals_tensor);
}
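[Note — not part of the diff] The expected scores can be checked by hand using the forward pass sketched in TestMNISTGrad's pseudo-code (mm = X·W1, hidden = Relu(mm), scores = hidden·W2 — the order that reproduces the numbers): X·W1 = [[1,2],[3,4]]·[[-1,10],[.5,1]] = [[0,12],[-1,34]]; Relu gives [[0,12],[0,34]]; multiplying by W2 = [[.1,.2],[.3,-.5]] yields [[3.6,-6],[10.2,-17]], i.e. expected_scores.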

TEST_P(CppGradients, TestMNISTForward2) {
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t X_dims[] = {3, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1, 1};
  int64_t y_dims[] = {3};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  GradientRegistry registry;

  // Run the Forward Pass
  std::vector<AbstractTensorHandle*> outputs(2);
  Status s =
      RunModel(MNISTForwardModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[6] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

  float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 6; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }

  TF_Tensor* loss_vals_tensor;
  s = GetValue(outputs[1], &loss_vals_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
         TF_TensorByteSize(loss_vals_tensor));
  float expected_losses[3] = {9.6f, 27.2f, 44.8f};
  for (int j = 0; j < 3; j++) {
    ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  TF_DeleteTensor(scores_tensor);
  TF_DeleteTensor(loss_vals_tensor);
}

TEST_P(CppGradients, TestMatMulTranspose) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
  int64_t X_dims[] = {2, 3};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  GradientRegistry registry;

  // Run the MatMul Op
  std::vector<AbstractTensorHandle*> outputs(1);

  Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);

  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Verify the Results
  TF_Tensor* scores_tensor;
  s = GetValue(outputs[0], &scores_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[6] = {0};
  memcpy(&result_data[0], TF_TensorData(scores_tensor),
         TF_TensorByteSize(scores_tensor));

  float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
  float tolerance = 1e-3;
  for (int j = 0; j < 6; j++) {
    ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
  }
}
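[Note — not part of the diff] MatMulTransposeModel evidently computes Xᵀ·W1 (the model body is not shown in this diff): Xᵀ = [[1,4],[2,5],[3,6]] times W1 = [[1,2],[3,4]] gives [[13,18],[17,24],[21,30]], exactly expected_scores.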

TEST_P(CppGradients, TestMNISTGrad) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // W1 = first weights
  float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  /* Pseudo-code:
   *
   *
   * tape.watch(W1)
   * tape.watch(W2)
   * mm = X*W1
   * hidden = Relu(mm)
   * scores = W2*hidden
   * loss = SoftmaxLoss(scores, y)
   * outputs = tape.gradient(loss, [A, B])
   *
   */

  std::vector<AbstractTensorHandle*> outputs(3);
  s = RunModel(MNISTGradModel, ctx.get(),
               {X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
               /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float tolerance = 1e-3;
  TF_Tensor* dW1_tensor;
  s = GetValue(outputs[0], &dW1_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dW1_tensor),
         TF_TensorByteSize(dW1_tensor));

  float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
  }

  TF_Tensor* dW2_tensor;
  s = GetValue(outputs[1], &dW2_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  memcpy(&result_data[0], TF_TensorData(dW2_tensor),
         TF_TensorByteSize(dW2_tensor));

  float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f};  // dLoss
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
  }

  outputs[0]->Unref();
  outputs[1]->Unref();
  outputs[2]->Unref();
  TF_DeleteTensor(dW1_tensor);
  TF_DeleteTensor(dW2_tensor);
}

TEST_P(CppGradients, TestScalarMul) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  AbstractTensorHandlePtr eta;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    eta.reset(x_raw);
  }

  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t A_dims[] = {2, 2};
  int num_dims = 2;

  AbstractTensorHandlePtr A =
      GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);

  GradientRegistry registry;
  std::vector<AbstractTensorHandle*> outputs(1);
  Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
                      absl::MakeSpan(outputs),
                      /*use_function=*/!std::get<2>(GetParam()), registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  TF_Tensor* dA_tensor;
  s = GetValue(outputs[0], &dA_tensor);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  float result_data[4] = {0};
  memcpy(&result_data[0], TF_TensorData(dA_tensor),
         TF_TensorByteSize(dA_tensor));

  float tolerance = 1e-3;
  float eta_val = 1.5f;
  for (int j = 0; j < 4; j++) {
    ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
  }

  outputs[0]->Unref();
  TF_DeleteTensor(dA_tensor);
}

TEST_P(CppGradients, TestMNIST_Training) {
  bool use_function = !std::get<2>(GetParam());
  if (use_function) {
    // TODO(b/168850692): Enable this.
    GTEST_SKIP() << "Can't take gradient of "
                    "SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
  }
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);

  AbstractContextPtr ctx;
  {
    AbstractContext* ctx_raw = nullptr;
    Status s =
        BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    ctx.reset(ctx_raw);
  }

  // X = data
  float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
  int64_t X_dims[] = {2, 2};
  int num_dims = 2;
  AbstractTensorHandlePtr X =
      GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);

  // TODO(amturati): use random initializer for weights instead of
  // constant values.

  // W1 = first weights
  float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
  int64_t dims[] = {2, 2};
  AbstractTensorHandlePtr W1 =
      GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);

  // W2 = second weights
  float W2_vals[] = {.1f, .2f, .3f, -.5f};
  AbstractTensorHandlePtr W2 =
      GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);

  // y = labels
  int y_vals[] = {1, 1};
  int64_t y_dims[] = {2};
  num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
  AbstractTensorHandlePtr y =
      GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);

  // Register Grads
  GradientRegistry registry;
  Status s = RegisterGradients(&registry);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Prepare for training
  std::vector<AbstractTensorHandle*> weights;
  weights.push_back(W1.get());
  weights.push_back(W2.get());

  // Set learning rate to be 1e-1
  AbstractTensorHandle* learning_rate = nullptr;
  s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();

  // Train
  int num_iters = 10;
  std::vector<AbstractTensorHandle*> mnist_outputs(3);
  std::vector<AbstractTensorHandle*> grads(2);
  for (int i = 0; i < num_iters; i++) {
    // Run Forward Pass
    s = RunModel(MNISTGradModel, ctx.get(),
                 {X.get(), weights[0], weights[1], y.get()},
                 absl::MakeSpan(mnist_outputs),
                 /*use_function=*/!std::get<2>(GetParam()), registry);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();

    // Fill grads
    grads[0] = mnist_outputs[0];
    grads[1] = mnist_outputs[1];

    // Gradient Update
    s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  }

  grads[0]->Unref();          // release W1_grad
  grads[1]->Unref();          // release W2_grad
  mnist_outputs[2]->Unref();  // release loss
}

#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*executing_eagerly*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
@ -1,244 +0,0 @@
|
||||
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
|
||||
|
||||
#include <memory>
|
||||
|
||||
#include "absl/container/flat_hash_set.h"
|
||||
#include "absl/types/span.h"
|
||||
#include "tensorflow/c/eager/abstract_tensor_handle.h"
|
||||
#include "tensorflow/c/eager/c_api_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental.h"
|
||||
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
|
||||
#include "tensorflow/c/eager/gradients.h"
|
||||
#include "tensorflow/c/eager/gradients_internal.h"
|
||||
#include "tensorflow/c/eager/gradients_util.h"
|
||||
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
|
||||
#include "tensorflow/c/experimental/ops/array_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/math_ops.h"
|
||||
#include "tensorflow/c/experimental/ops/nn_ops.h"
|
||||
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace gradients {
|
||||
namespace internal {
|
||||
|
||||
using std::vector;
|
||||
|
||||
//===================== Test Models to run =========================
|
||||
|
||||
// Computes
|
||||
// y = inputs[0] + inputs[1]
|
||||
// return grad(y, {inputs[0], inputs[1]})
|
||||
Status AddGradModel(AbstractContext* ctx,
|
||||
absl::Span<AbstractTensorHandle* const> inputs,
|
||||
absl::Span<AbstractTensorHandle*> outputs,
|
||||
const GradientRegistry& registry) {
|
||||
auto tape = new Tape(/*persistent=*/false);
|
||||
tape->Watch(inputs[0]); // Watch x.
|
||||
tape->Watch(inputs[1]); // Watch y.
|
||||
std::vector<AbstractTensorHandle*> add_outputs(1);
|
||||
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
|
||||
TF_RETURN_IF_ERROR(
|
||||
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
|
||||
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
|
||||
/*sources=*/inputs,
|
||||
/*output_gradients=*/{}, outputs));
|
||||
for (auto add_output : add_outputs) {
|
||||
add_output->Unref();
|
||||
}
|
||||
delete tape;
|
||||
return Status::OK();
|
||||
}
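
// A sketch of how a test drives this model (mirroring the call sites earlier
// in this diff; `ctx`, `x`, `y`, and `registry` are assumed to be set up as
// in those tests):
//
//   std::vector<AbstractTensorHandle*> outputs(2);
//   Status s = RunModel(AddGradModel, ctx, {x, y}, absl::MakeSpan(outputs),
//                       /*use_function=*/false, registry);
//   // outputs[0] and outputs[1] then hold d(x + y)/dx and d(x + y)/dy.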

// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       const GradientRegistry& registry) {
  auto tape = new Tape(/*persistent=*/false);
  tape->Watch(inputs[0]);  // Watch x.
  tape->Watch(inputs[1]);  // Watch y.
  vector<AbstractTensorHandle*> mm_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
                                 absl::MakeSpan(mm_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute x*y.

  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mm_outputs,
                                           /*sources=*/inputs,
                                           /*output_gradients=*/{}, outputs));
  for (auto mm_output : mm_outputs) {
    mm_output->Unref();
  }
  delete tape;
  return Status::OK();
}

// Model to run 2-layer net
Status MNISTForwardModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
                         absl::Span<AbstractTensorHandle*> outputs,
                         const GradientRegistry& registry) {
  /**
   * We will trace a 2-layer fully connected network for an MNIST model:
   *
   *   def mnist_forward(X, W1, W2, y_labels):
   *     mm_out_1 = tf.matmul(X, W1)
   *     hidden_layer = tf.nn.relu(mm_out_1)
   *     scores = tf.matmul(hidden_layer, W2)
   *     softmax = tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
   *                                                              y_labels)
   *     return scores, softmax
   *
   * Use this convention for inputs:
   *
   *   inputs = [X, W1, W2, y_labels]
   */
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  vector<AbstractTensorHandle*> temp_outputs(1);

  TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, absl::MakeSpan(temp_outputs),
                                 "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  TF_RETURN_IF_ERROR(ops::Relu(ctx, {temp_outputs[0]},
                               absl::MakeSpan(temp_outputs),
                               "relu"));  // Compute Relu(X*W1)

  TF_RETURN_IF_ERROR(ops::MatMul(
      ctx, {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), "matmul1",
      /*transpose_a=*/false, /*transpose_b=*/false));  // Compute Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      ctx, {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmax_loss"));  // Compute Softmax(scores, y_labels)

  AbstractTensorHandle* loss_vals = temp_outputs[0];

  outputs[0] = scores;
  outputs[1] = loss_vals;
  return Status::OK();
}

Status MatMulTransposeModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];

  TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, outputs, "matmul0",
                                 /*transpose_a=*/true,
                                 /*transpose_b=*/false));  // Compute X^T*W1
  return Status::OK();
}

Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  AbstractTensorHandle* X = inputs[0];
  AbstractTensorHandle* W1 = inputs[1];
  AbstractTensorHandle* W2 = inputs[2];
  AbstractTensorHandle* y_labels = inputs[3];

  auto tape = new Tape(/*persistent=*/true);
  tape->Watch(X);   // Watch X.
  tape->Watch(W1);  // Watch W1.
  tape->Watch(W2);  // Watch W2.
  vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
  TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
                                 absl::MakeSpan(temp_outputs), "matmul0",
                                 /*transpose_a=*/false,
                                 /*transpose_b=*/false));  // Compute X*W1

  AbstractTensorHandle* mm = temp_outputs[0];

  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
                               absl::MakeSpan(temp_outputs),  // Relu(X*W1)
                               "relu0"));

  AbstractTensorHandle* hidden = temp_outputs[0];

  TF_RETURN_IF_ERROR(ops::MatMul(
      tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
      /*transpose_a=*/false, /*transpose_b=*/false));  // Relu(X*W1)*W2

  AbstractTensorHandle* scores = temp_outputs[0];

  temp_outputs.resize(2);
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
      "softmaxloss"));  // Compute Softmax(scores, y_labels)

  AbstractTensorHandle* loss = temp_outputs[0];

  TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/{loss},
                                           /*sources=*/{W1, W2},
                                           /*output_gradients=*/{},
                                           outputs.subspan(0, 2)));

  // Only release the 2nd temp output, as the first holds the loss values.
  temp_outputs[1]->Unref();
  outputs[2] = loss;
  delete tape;
  return Status::OK();
}

Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry) {
  return ops::Mul(ctx, inputs, outputs,
                  "scalarMul0");  // Compute eta*A
}

Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry) {
  return ops::MatMul(ctx, inputs, outputs, "matmul0",
                     /*transpose_a=*/false,
                     /*transpose_b=*/false);  // Compute X*W1
}

Status MulModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs,
                const GradientRegistry& registry) {
  return ops::Mul(ctx, inputs, outputs,
                  "mul0");  // Compute x*y
}

// ============================= End Models ================================

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
@ -1,90 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#include <memory>

#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"

namespace tensorflow {
namespace gradients {
namespace internal {

// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs,
                    const GradientRegistry& registry);

// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
                       absl::Span<AbstractTensorHandle* const> inputs,
                       absl::Span<AbstractTensorHandle*> outputs,
                       const GradientRegistry& registry);

// Computes 2-layer Neural Network with Softmax Loss.
Status MNISTForwardModel(AbstractContext* ctx,
                         absl::Span<AbstractTensorHandle* const> inputs,
                         absl::Span<AbstractTensorHandle*> outputs,
                         const GradientRegistry& registry);

// Computes MatMul with first matrix transposed.
Status MatMulTransposeModel(AbstractContext* ctx,
                            absl::Span<AbstractTensorHandle* const> inputs,
                            absl::Span<AbstractTensorHandle*> outputs,
                            const GradientRegistry& registry);

// Test Model to verify Multi-grad functionality for MNIST
Status MNISTGradModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry);

// Test Model to verify scalar-tensor multiplication Op
Status ScalarMulModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs,
                      const GradientRegistry& registry);

Status MatMulModel(AbstractContext* ctx,
                   absl::Span<AbstractTensorHandle* const> inputs,
                   absl::Span<AbstractTensorHandle*> outputs,
                   const GradientRegistry& registry);

Status MulModel(AbstractContext* ctx,
                absl::Span<AbstractTensorHandle* const> inputs,
                absl::Span<AbstractTensorHandle*> outputs,
                const GradientRegistry& registry);

}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow

#endif  // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
@ -68,7 +68,7 @@ TEST_P(UnifiedAPI, TestTensorShapeScalar) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
    Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }
@ -120,8 +120,8 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) {
    AbstractTensorHandle* x_raw = nullptr;
    float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
    int64_t dim_sizes[] = {2, 4};
    Status s =
        TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
    Status s = TestTensorHandleWithDims<float, TF_FLOAT>(ctx.get(), data,
                                                         dim_sizes, 2, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }
@ -130,49 +130,6 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
  return Status::OK();
}

Status TestScalarTensorHandle(AbstractContext* ctx, float value,
                              AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
                                     int64_t* dims, int num_dims,
                                     AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
                                   int64_t* dims, int num_dims,
                                   AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}

Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
@ -17,6 +17,10 @@ limitations under the License.

#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/status.h"

@ -49,19 +53,38 @@ Status RunModel(Model model, AbstractContext* ctx,

Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);

// Get a Scalar TensorHandle with given float value.
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
                              AbstractTensorHandle** tensor);
// Return a tensor handle with given type, values and dimensions.
template <class T, TF_DataType datatype>
Status TestTensorHandleWithDims(AbstractContext* ctx, const T* data,
                                const int64_t* dims, int num_dims,
                                AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestTensorHandleWithDims<T, datatype>(eager_ctx, data, dims, num_dims);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}
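
// A minimal usage sketch for the templated helper above (hypothetical call
// site; `ctx` is assumed to come from BuildImmediateExecutionContext):
//
//   float vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
//   int64_t dims[] = {2, 2};
//   AbstractTensorHandle* t = nullptr;
//   Status s = TestTensorHandleWithDims<float, TF_FLOAT>(ctx, vals, dims,
//                                                        /*num_dims=*/2, &t);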

// Get a Matrix TensorHandle with given float values and dimensions.
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
                                     int64_t* dims, int num_dims,
                                     AbstractTensorHandle** tensor);

// Get a TensorHandle with given int values and dimensions.
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
                                   int64_t* dims, int num_dims,
                                   AbstractTensorHandle** tensor);
// Return a scalar tensor handle with given value.
template <class T, TF_DataType datatype>
Status TestScalarTensorHandle(AbstractContext* ctx, const T value,
                              AbstractTensorHandle** tensor) {
  std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
      TF_NewStatus(), TF_DeleteStatus);
  TFE_Context* eager_ctx =
      TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
  TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
  TFE_TensorHandle* input_eager =
      TestScalarTensorHandle<T, datatype>(eager_ctx, value);
  *tensor =
      unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
  return Status::OK();
}
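
// Likewise, a hypothetical scalar call site for the template above:
//
//   AbstractTensorHandle* s_handle = nullptr;
//   Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx, 2.0f, &s_handle);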

// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
@ -190,6 +190,7 @@ tf_cuda_cc_test(
        "//tensorflow/c/eager:c_api_test_util",
        "//tensorflow/c/experimental/gradients/tape:tape_context",
        "//tensorflow/c/experimental/ops:nn_ops",
        "//tensorflow/core/platform:tensor_float_32_utils",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ] + if_libtpu(
@ -222,3 +223,28 @@ tf_cuda_cc_test(
        if_true = [],
    ),
)

tf_cuda_cc_test(
    name = "array_grad_test",
    size = "small",
    srcs = [
        "array_grad_test.cc",
    ],
    args = ["--heap_check=local"],  # TODO(b/174752220): Remove
    linkstatic = tf_kernel_tests_linkstatic(),
    tags = tf_cuda_tests_tags() + ["no_cuda_asan"],  # b/173654156
    deps = [
        ":grad_test_helper",
        ":array_grad",
        "//tensorflow/c:tf_status_helper",
        "//tensorflow/c/eager:c_api_test_util",
        "//tensorflow/c/experimental/gradients/tape:tape_context",
        "//tensorflow/c/experimental/ops:array_ops",
        "//tensorflow/core/platform:tensor_float_32_utils",
        "//tensorflow/core:test",
        "//tensorflow/core:test_main",
    ] + if_libtpu(
        if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
        if_true = [],
    ),
)

133
tensorflow/c/experimental/gradients/array_grad_test.cc
Normal file
@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"

#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace gradients {
namespace internal {
namespace {

using tensorflow::TF_StatusPtr;

Status IdentityNModel(AbstractContext* ctx,
                      absl::Span<AbstractTensorHandle* const> inputs,
                      absl::Span<AbstractTensorHandle*> outputs) {
  std::vector<AbstractTensorHandle*> temp_outputs(2);
  TF_RETURN_IF_ERROR(
      ops::IdentityN(ctx, inputs, absl::MakeSpan(temp_outputs), "IdentityN"));
  // Although `ops::IdentityN` returns 2 tensors, the first tensor isn't needed
  // for computing the gradient, so we can safely drop it.
  outputs[0] = temp_outputs[1];
  temp_outputs[0]->Unref();
  return Status::OK();
}

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    status_ = StatusFromTF_Status(status.get());
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

    {
      AbstractContext* ctx_raw = nullptr;
      status_ =
          BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
      ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
      immediate_execution_ctx_.reset(ctx_raw);
    }

    // Computing numerical gradients with TensorFloat-32 is numerically
    // unstable. Some forward pass tests also fail with TensorFloat-32 due to
    // low tolerances.
    enable_tensor_float_32_execution(false);
  }

  AbstractContextPtr immediate_execution_ctx_;
  GradientRegistry registry_;
  Status status_;

 public:
  bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
  bool UseFunction() const { return std::get<2>(GetParam()); }
};

TEST_P(CppGradients, TestIdentityNGrad) {
  // This test is interesting because the current implementation of
  // GradientTape would return [0, 1], whereas we use
  // build_default_zeros_grads=false here, so we get back [nullptr, 1].

  AbstractTensorHandlePtr x1;
  {
    AbstractTensorHandle* x1_raw = nullptr;
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 1.0f, &x1_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x1.reset(x1_raw);
  }

  AbstractTensorHandlePtr x2;
  {
    AbstractTensorHandle* x2_raw = nullptr;
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 1.0f, &x2_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x2.reset(x2_raw);
  }

  status_ = registry_.Register("IdentityN", IdentityNRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
  auto IdentityNGradModel = BuildGradModel(IdentityNModel, registry_);

  std::vector<AbstractTensorHandle*> outputs(2);
  status_ =
      RunModel(IdentityNGradModel, immediate_execution_ctx_.get(),
               {x1.get(), x2.get()}, absl::MakeSpan(outputs), UseFunction());
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
  EXPECT_EQ(outputs[0], nullptr);
  ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], {1.0f}, /*dims*/ {},
                                           /*abs_error*/ 0));
  outputs[1]->Unref();
}

#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
    UnifiedCAPI, CppGradients,
    ::testing::Combine(::testing::Values("graphdef", "mlir"),
                       /*tfrt*/ ::testing::Values(false),
                       /*use_function*/ ::testing::Values(true, false)));
#endif
}  // namespace
}  // namespace internal
}  // namespace gradients
}  // namespace tensorflow
@ -101,7 +101,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
    Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 1.0f, &x_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    x.reset(x_raw);
  }
@ -90,9 +90,9 @@ class CppGradients

    {
      AbstractContext* ctx_raw = nullptr;
      Status s =
      status_ =
          BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
      ASSERT_EQ(errors::OK, s.code()) << s.error_message();
      ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
      immediate_execution_ctx_.reset(ctx_raw);
    }

@ -115,8 +115,8 @@ TEST_P(CppGradients, TestAddGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -124,12 +124,14 @@ TEST_P(CppGradients, TestAddGrad) {
  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &y_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    y.reset(y_raw);
  }

  // TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
  // AddV2Registerer.
  status_ = registry_.Register("AddV2", AddRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

@ -142,8 +144,8 @@ TEST_P(CppGradients, TestExpGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -167,8 +169,8 @@ TEST_P(CppGradients, TestMatMulGrad) {
  AbstractTensorHandlePtr A;
  {
    AbstractTensorHandle* A_raw;
    status_ = TestTensorHandleWithDimsFloat(immediate_execution_ctx_.get(),
                                            A_vals, A_dims, 2, &A_raw);
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    A.reset(A_raw);
  }
@ -178,8 +180,8 @@ TEST_P(CppGradients, TestMatMulGrad) {
  AbstractTensorHandlePtr B;
  {
    AbstractTensorHandle* B_raw;
    status_ = TestTensorHandleWithDimsFloat(immediate_execution_ctx_.get(),
                                            B_vals, B_dims, 2, &B_raw);
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), B_vals, B_dims, 2, &B_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    B.reset(B_raw);
  }
@ -204,12 +206,77 @@ TEST_P(CppGradients, TestMatMulGrad) {
  }
}

TEST_P(CppGradients, TestMatMulGradManual) {
  float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
  int64_t A_dims[] = {3, 3};
  AbstractTensorHandlePtr A;
  {
    AbstractTensorHandle* A_raw;
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    A.reset(A_raw);
  }

  float B_vals[] = {9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f};
  int64_t B_dims[] = {3, 3};
  AbstractTensorHandlePtr B;
  {
    AbstractTensorHandle* B_raw;
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), B_vals, B_dims, 2, &B_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    B.reset(B_raw);
  }

  status_ = registry_.Register("MatMul", MatMulRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

  bool transpose_a_vals[] = {false, false, true, true};
  bool transpose_b_vals[] = {false, true, false, true};
  float dA_vals[4][9] = {{24, 15, 6, 24, 15, 6, 24, 15, 6},
                         {18, 15, 12, 18, 15, 12, 18, 15, 12},
                         {24, 24, 24, 15, 15, 15, 6, 6, 6},
                         {18, 18, 18, 15, 15, 15, 12, 12, 12}};
  float dB_vals[4][9] = {{12, 12, 12, 15, 15, 15, 18, 18, 18},
                         {12, 15, 18, 12, 15, 18, 12, 15, 18},
                         {6, 6, 6, 15, 15, 15, 24, 24, 24},
                         {6, 15, 24, 6, 15, 24, 6, 15, 24}};

  for (int i{}; i < 4; ++i) {
    bool transpose_a = transpose_a_vals[i];
    bool transpose_b = transpose_b_vals[i];
    Model MatMulModel =
        [transpose_a, transpose_b](
            AbstractContext* ctx,
            absl::Span<AbstractTensorHandle* const> inputs,
            absl::Span<AbstractTensorHandle*> outputs) -> Status {
      return ops::MatMul(ctx, inputs, outputs, "MatMul", transpose_a,
                         transpose_b);
    };
    Model MatMulGradModel = BuildGradModel(MatMulModel, registry_);
    std::vector<AbstractTensorHandle*> outputs(2);
    status_ =
        RunModel(MatMulGradModel, immediate_execution_ctx_.get(),
                 {A.get(), B.get()}, absl::MakeSpan(outputs), UseFunction());
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], dA_vals[i],
                                             /*dims*/ {3, 3},
                                             /*abs_error*/ 0));
    ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], dB_vals[i],
                                             /*dims*/ {3, 3},
                                             /*abs_error*/ 0));
    outputs[0]->Unref();
    outputs[1]->Unref();
  }
}

TEST_P(CppGradients, TestSqrtGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -226,8 +293,8 @@ TEST_P(CppGradients, TestNegGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -244,8 +311,8 @@ TEST_P(CppGradients, TestSubGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -253,8 +320,8 @@ TEST_P(CppGradients, TestSubGrad) {
  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &y_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    y.reset(y_raw);
  }
@ -271,8 +338,8 @@ TEST_P(CppGradients, TestMulGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -280,8 +347,8 @@ TEST_P(CppGradients, TestMulGrad) {
  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &y_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    y.reset(y_raw);
  }
@ -298,8 +365,8 @@ TEST_P(CppGradients, TestLog1pGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -321,8 +388,8 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
  AbstractTensorHandlePtr x;
  {
    AbstractTensorHandle* x_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &x_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    x.reset(x_raw);
  }
@ -330,8 +397,8 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
  AbstractTensorHandlePtr y;
  {
    AbstractTensorHandle* y_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 2.0f, &y_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    y.reset(y_raw);
  }
@ -344,8 +411,8 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
  AbstractTensorHandlePtr z;
  {
    AbstractTensorHandle* z_raw = nullptr;
    status_ =
        TestScalarTensorHandle(immediate_execution_ctx_.get(), 0.0f, &z_raw);
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 0.0f, &z_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    z.reset(z_raw);
  }
@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
@ -35,28 +36,6 @@ Status ReluModel(AbstractContext* ctx,
  return ops::Relu(ctx, inputs, outputs, "Relu");
}

Status ReluGradModel(AbstractContext* ctx,
                     absl::Span<AbstractTensorHandle* const> inputs,
                     absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("Relu", ReluRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
                               absl::MakeSpan(temp_outputs), "ReluGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status SparseSoftmaxCrossEntropyWithLogitsModel(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs) {
@ -73,80 +52,38 @@ Status SparseSoftmaxCrossEntropyWithLogitsModel(
  return Status::OK();
}

Status SparseSoftmaxCrossEntropyWithLogitsGradModel(
    AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
    absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(
      registry.Register("SparseSoftmaxCrossEntropyWithLogits",
                        SparseSoftmaxCrossEntropyWithLogitsRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);  // Watch score.
  tape.Watch(inputs[1]);  // Watch label.
  std::vector<AbstractTensorHandle*> temp_outputs(2);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
      tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs),
      "SparseSoftmaxCrossEntropyWithLogitsGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

Status BiasAddModel(AbstractContext* ctx,
                    absl::Span<AbstractTensorHandle* const> inputs,
                    absl::Span<AbstractTensorHandle*> outputs) {
  return ops::BiasAdd(ctx, inputs, outputs, "BiasAdd");
}

Status BiasAddGradModel(AbstractContext* ctx,
                        absl::Span<AbstractTensorHandle* const> inputs,
                        absl::Span<AbstractTensorHandle*> outputs) {
  GradientRegistry registry;
  TF_RETURN_IF_ERROR(registry.Register("BiasAdd", BiasAddRegisterer));

  Tape tape(/*persistent=*/false);
  tape.Watch(inputs[0]);  // Watch A.
  tape.Watch(inputs[1]);  // Watch Bias.
  std::vector<AbstractTensorHandle*> temp_outputs(1);
  AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
  TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs,
                                  absl::MakeSpan(temp_outputs), "BiasAddGrad"));

  TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
                                          /*sources=*/inputs,
                                          /*output_gradients=*/{}, outputs));
  for (auto temp_output : temp_outputs) {
    temp_output->Unref();
  }
  return Status::OK();
}

class CppGradients
    : public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
 protected:
  void SetUp() override {
    TF_StatusPtr status(TF_NewStatus());
    TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
    Status s = StatusFromTF_Status(status.get());
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    status_ = StatusFromTF_Status(status.get());
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

    {
      AbstractContext* ctx_raw = nullptr;
      Status s =
      status_ =
          BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
      ASSERT_EQ(errors::OK, s.code()) << s.error_message();
      ctx_.reset(ctx_raw);
      ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
      immediate_execution_ctx_.reset(ctx_raw);
    }

    // Computing numerical gradients with TensorFloat-32 is numerically
    // unstable. Some forward pass tests also fail with TensorFloat-32 due to
    // low tolerances.
    enable_tensor_float_32_execution(false);
  }

  AbstractContextPtr ctx_;
  AbstractContextPtr immediate_execution_ctx_;
  GradientRegistry registry_;
  Status status_;

 public:
  bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
@ -154,34 +91,41 @@ class CppGradients
};

TEST_P(CppGradients, TestReluGrad) {
  status_ = registry_.Register("Relu", ReluRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

  auto ReluGradModel = BuildGradModel(ReluModel, registry_);

  float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 10.0f, -1.0f};
  int64_t X_dims[] = {3, 3};
  AbstractTensorHandlePtr X;
  {
    AbstractTensorHandle* X_raw;
    Status s =
        TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    X.reset(X_raw);
  }

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      ReluModel, ReluGradModel, ctx_.get(), {X.get()}, UseFunction()));
      ReluModel, ReluGradModel, immediate_execution_ctx_.get(), {X.get()},
      UseFunction()));

  // Mathematically, Relu isn't differentiable at `0`, so `gradient_checker`
  // does not work with it.
  AbstractTensorHandlePtr Y;
  {
    AbstractTensorHandle* Y_raw;
    Status s = TestScalarTensorHandle(ctx_.get(), 0.0f, &Y_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    status_ = TestScalarTensorHandle<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), 0.0f, &Y_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    Y.reset(Y_raw);
  }

  std::vector<AbstractTensorHandle*> outputs(1);
  auto s = RunModel(ReluGradModel, ctx_.get(), {Y.get()},
                    absl::MakeSpan(outputs), UseFunction());
  ASSERT_EQ(errors::OK, s.code()) << s.error_message();
  status_ = RunModel(ReluGradModel, immediate_execution_ctx_.get(), {Y.get()},
                     absl::MakeSpan(outputs), UseFunction());
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
  ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {},
                                           /*abs_error*/ 0));
  outputs[0]->Unref();
@ -200,27 +144,31 @@ TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
  AbstractTensorHandlePtr X;
  {
    AbstractTensorHandle* X_raw;
    Status s =
        TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    X.reset(X_raw);
  }
  // Label
  int Y_vals[] = {1, 0, 1};
  int32_t Y_vals[] = {1, 0, 1};
  int64_t Y_dims[] = {3};
  AbstractTensorHandlePtr Y;
  {
    AbstractTensorHandle* Y_raw;
    Status s =
        TestTensorHandleWithDimsInt(ctx_.get(), Y_vals, Y_dims, 1, &Y_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    status_ = TestTensorHandleWithDims<int32_t, TF_INT32>(
        immediate_execution_ctx_.get(), Y_vals, Y_dims, 1, &Y_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    Y.reset(Y_raw);
  }

  status_ = registry_.Register("SparseSoftmaxCrossEntropyWithLogits",
                               SparseSoftmaxCrossEntropyWithLogitsRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      SparseSoftmaxCrossEntropyWithLogitsModel,
      SparseSoftmaxCrossEntropyWithLogitsGradModel, ctx_.get(),
      {X.get(), Y.get()}, UseFunction()));
      BuildGradModel(SparseSoftmaxCrossEntropyWithLogitsModel, registry_),
      immediate_execution_ctx_.get(), {X.get(), Y.get()}, UseFunction()));
}

TEST_P(CppGradients, TestBiasAddGrad) {
@ -234,9 +182,9 @@ TEST_P(CppGradients, TestBiasAddGrad) {
  AbstractTensorHandlePtr A;
  {
    AbstractTensorHandle* A_raw;
    Status s =
        TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    A.reset(A_raw);
  }
  // Bias
@ -245,15 +193,18 @@ TEST_P(CppGradients, TestBiasAddGrad) {
  AbstractTensorHandlePtr Bias;
  {
    AbstractTensorHandle* Bias_raw;
    Status s = TestTensorHandleWithDimsFloat(ctx_.get(), Bias_vals, Bias_dims,
                                             1, &Bias_raw);
    ASSERT_EQ(errors::OK, s.code()) << s.error_message();
    status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
        immediate_execution_ctx_.get(), Bias_vals, Bias_dims, 1, &Bias_raw);
    ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
    Bias.reset(Bias_raw);
  }

  status_ = registry_.Register("BiasAdd", BiasAddRegisterer);
  ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();

  ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
      BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
      UseFunction()));
      BiasAddModel, BuildGradModel(BiasAddModel, registry_),
      immediate_execution_ctx_.get(), {A.get(), Bias.get()}, UseFunction()));
}

#ifdef PLATFORM_GOOGLE
@ -141,6 +141,72 @@ AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
  return inversePermutation(map);
}

/// Returns true if the given `dimensionNumbers` from a mhlo.convolution op
/// follows a canonical form:
///
/// * Input dimensions have order: (batch_count, spatial_dims,
///   input_channel_count).
/// * Filter dimensions have order: (spatial_dims, input_channel_count,
///   output_channel_count).
/// * Output dimensions have order: (batch_count, spatial_dims,
///   output_channel_count).
template <typename DimensionNumbersTy>
static bool HasCanonicalDimensionNumbers(
    const DimensionNumbersTy& dimension_numbers) {
  const int input_spatial_rank =
      llvm::size(dimension_numbers.input_spatial_dimensions());
  // The dimensions for input should follow the order of
  // batch_count, spatial_dims..., input_feature_count.
  if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
      dimension_numbers.input_feature_dimension().getInt() !=
          (input_spatial_rank + 1)) {
    return false;
  }

  const int kernel_spatial_rank =
      llvm::size(dimension_numbers.kernel_spatial_dimensions());
  // The dimensions for filter should follow the order of
  // spatial_dims..., input_feature_count, output_feature_count.
  if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
          kernel_spatial_rank ||
      dimension_numbers.kernel_output_feature_dimension().getInt() !=
          (kernel_spatial_rank + 1)) {
    return false;
  }

  const int output_spatial_rank =
      llvm::size(dimension_numbers.output_spatial_dimensions());
  // The dimensions for output should follow the order of
  // batch_count, spatial_dims..., output_feature_count.
  if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
      dimension_numbers.output_feature_dimension().getInt() !=
          (output_spatial_rank + 1)) {
    return false;
  }

  if (input_spatial_rank != output_spatial_rank ||
      input_spatial_rank != kernel_spatial_rank) {
    return false;
  }

  auto input_spatial_dim = dimension_numbers.input_spatial_dimensions().begin();
  auto kernel_spatial_dim =
      dimension_numbers.kernel_spatial_dimensions().begin();
  auto output_spatial_dim =
      dimension_numbers.output_spatial_dimensions().begin();
  // Check if spatial dims are ordered correctly.
  for (int i = 0; i < input_spatial_rank; ++i) {
    const int dim = i + 1;
    if ((*input_spatial_dim++).getZExtValue() != dim ||
        (*output_spatial_dim++).getZExtValue() != dim ||
        (*kernel_spatial_dim++).getZExtValue() != i) {
      return false;
    }
  }

  return true;
}
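
// For intuition, a sketch of the layout this predicate accepts for a 2-D
// convolution (implied by the checks above: NHWC input/output with an HWCF
// filter):
//
//   input:  batch = 0, spatial = {1, 2}, feature = 3          (NHWC)
//   kernel: spatial = {0, 1}, input_feature = 2,
//           output_feature = 3                                (HWCF)
//   output: batch = 0, spatial = {1, 2}, feature = 3          (NHWC)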

template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
 public:
@ -264,61 +330,10 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
 public:
  using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;

  // This code has been adapted from IREE's
  // (https://github.com/google/iree/) mhlo -> linalg conversion.
  LogicalResult matchAndRewrite(
      lmhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const final {
    // Check validity of dimension information.
    if (const mhlo::ConvDimensionNumbers& dimension_numbers =
            op.dimension_numbers()) {
      const int input_spatial_rank =
          llvm::size(dimension_numbers.input_spatial_dimensions());
      // The dimensions for input should follow the order of
      // batch_count, spatial_dims..., input_feature_count.
      if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
          dimension_numbers.input_feature_dimension().getInt() !=
              (input_spatial_rank + 1))
        return failure();

      const int kernel_spatial_rank =
          llvm::size(dimension_numbers.kernel_spatial_dimensions());
      // The dimensions for filter should follow the order of
      // spatial_dims..., input_feature_count, num_output_feature_count.
      if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
              kernel_spatial_rank ||
          dimension_numbers.kernel_output_feature_dimension().getInt() !=
              (kernel_spatial_rank + 1))
        return failure();

      const int output_spatial_rank =
          llvm::size(dimension_numbers.output_spatial_dimensions());
      // The dimensions for output should follow the order of
      // batch_count, spatial_dims.., output_feature_count.
      if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
          dimension_numbers.output_feature_dimension().getInt() !=
              (output_spatial_rank + 1))
        return failure();

      if (input_spatial_rank != output_spatial_rank ||
          input_spatial_rank != kernel_spatial_rank)
        return failure();

      auto input_spatial_dim =
          dimension_numbers.input_spatial_dimensions().begin();
      auto kernel_spatial_dim =
          dimension_numbers.kernel_spatial_dimensions().begin();
      auto output_spatial_dim =
          dimension_numbers.output_spatial_dimensions().begin();
      // Check if spatial dims are ordered correctly.
      for (int i = 0; i < input_spatial_rank; ++i) {
        const int dim = i + 1;
        if ((*input_spatial_dim++).getZExtValue() != dim ||
            (*output_spatial_dim++).getZExtValue() != dim ||
            (*kernel_spatial_dim++).getZExtValue() != i)
          return failure();
      }
    }
    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();

    // TODO: LHS dilation for deconvolution not supported yet.
    // TODO(jurahul): Window reversal is not supported yet.
@ -1432,6 +1447,80 @@ struct PadOpOnTensorsConversion : public OpConversionPattern<mhlo::PadOp> {
  }
};

/// Converts mhlo.conv operation to linalg named op. This only covers normal
/// convolution cases. The op must have canonical dimension numbers. Depthwise
/// convolution and pointwise convolution are not handled in the conversion.
struct NormalConvOpOnTensorsConversion
    : public OpConversionPattern<mhlo::ConvOp> {
  using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;

  LogicalResult matchAndRewrite(
      mhlo::ConvOp op, ArrayRef<Value> args,
      ConversionPatternRewriter& rewriter) const override {
    if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
    if (op.feature_group_count() != 1u) return failure();

    mhlo::ConvOp::Adaptor adaptor(args);
    Location loc = op.getLoc();
    Value input = adaptor.lhs();
    Value filter = adaptor.rhs();
    auto result_type = op.getResult().getType().cast<ShapedType>();
    int64_t rank = result_type.getRank();

    // Check if padding is zero.
    DenseIntElementsAttr padding = op.paddingAttr();
    if (padding &&
        (!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
      return rewriter.notifyMatchFailure(op, "expected no padding");
    }

    // The output shape is (N, spatial_dims..., F).
    SmallVector<Value, 8> dyn_sizes;
    for (int64_t i = 0, e = rank - 1; i < e; ++i) {
      if (!result_type.isDynamicDim(i)) continue;
      dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
    }
    if (result_type.isDynamicDim(rank - 1)) {
      dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
    }
    Value init_tensor = rewriter.create<linalg::InitTensorOp>(
        loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
    auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
    Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
    Value zero_tensor =
        rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
    linalg::LinalgOp res;
    Attribute strides = op.window_stridesAttr();
    // TODO(ataei): Only support dilated kernel right now. We need to consider
    // input dilation for deconvolution cases.
    Attribute dilations = op.rhs_dilationAttr();
    switch (rank) {
      case 3: {
        res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      case 4: {
        res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      case 5: {
        res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
            loc, result_type, ValueRange{input, filter},
            ValueRange{zero_tensor}, dilations, strides);
        break;
      }
      default:
        return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
    }
    rewriter.replaceOp(op, res.getOperation()->getResults());
    return success();
  }
};
|
||||
|
||||
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
OwningRewritePatternList* patterns) {
|
||||
// clang-format off
|
||||
@ -1656,6 +1745,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
|
||||
linalg::BatchMatmulI32I32I32Op>,
|
||||
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
|
||||
linalg::BatchMatmulOp>,
|
||||
NormalConvOpOnTensorsConversion,
|
||||
ReduceOnTensorsConversion,
|
||||
PadOpOnTensorsConversion>(context);
|
||||
// clang-format on
|
||||
|
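As a reading aid for this hunk, here is a minimal standalone sketch of the canonical-dimension-number condition the new pattern relies on. The helper name and signature are hypothetical (not part of this commit); it only restates the check above: input/output spatial dims must be 1..rank (NHWC-style layouts) and kernel spatial dims 0..rank-1 (HWCF-style).

#include <cstddef>
#include <cstdint>
#include <vector>

// Hypothetical helper restating the canonicity check above (illustrative
// only, not part of this commit).
bool HasCanonicalSpatialDims(const std::vector<std::int64_t>& input_spatial,
                             const std::vector<std::int64_t>& kernel_spatial,
                             const std::vector<std::int64_t>& output_spatial) {
  if (input_spatial.size() != output_spatial.size() ||
      input_spatial.size() != kernel_spatial.size())
    return false;
  for (std::size_t i = 0; i < input_spatial.size(); ++i) {
    // Input/output spatial dims start at 1 (dim 0 is batch); kernel spatial
    // dims start at 0 (trailing dims are input/output feature).
    const std::int64_t dim = static_cast<std::int64_t>(i) + 1;
    if (input_spatial[i] != dim || output_spatial[i] != dim ||
        kernel_spatial[i] != static_cast<std::int64_t>(i))
      return false;
  }
  return true;
}

For example, a 2-D NHWC/HWCF convolution passes with input/output spatial dims {1, 2} and kernel spatial dims {0, 1}, matching the tests below.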
@@ -1441,3 +1441,171 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf32>
// CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
// CHECK: linalg.yield %[[PAD]] : f32
// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>

// -----

func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
  -> tensor<?x?x?xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 2 : i64,
      input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
      kernel_input_feature_dimension = 1 : i64,
      kernel_output_feature_dimension = 2 : i64,
      kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 2 : i64,
      output_spatial_dimensions = dense<[1]> : tensor<1xi64>
    },
    feature_group_count = 1 : i64,
    padding = dense<[[0], [0]]> : tensor<2x1xi64>,
    rhs_dilation = dense<1> : tensor<1xi64>,
    window_strides = dense<1> : tensor<1xi64>
  } : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
  return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>

// -----

func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
  -> tensor<?x?x?x?xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 3 : i64,
      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      kernel_input_feature_dimension = 2 : i64,
      kernel_output_feature_dimension = 3 : i64,
      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 3 : i64,
      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
    rhs_dilation = dense<1> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
  return %0 : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>

// -----

func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
  -> tensor<?x?x?x?x?xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 4 : i64,
      input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
      kernel_input_feature_dimension = 3 : i64,
      kernel_output_feature_dimension = 4 : i64,
      kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 4 : i64,
      output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
    },
    feature_group_count = 1 : i64,
    padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
    rhs_dilation = dense<1> : tensor<3xi64>,
    window_strides = dense<1> : tensor<3xi64>
  } : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
  return %0 : tensor<?x?x?x?x?xf32>
}
// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C4:.+]] = constant 4 : index
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>

// -----

func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>)
  -> tensor<1x2x4x3xf32> {
  %0 = "mhlo.convolution"(%arg0, %arg1) {
    batch_group_count = 1 : i64,
    dimension_numbers = {
      input_batch_dimension = 0 : i64,
      input_feature_dimension = 3 : i64,
      input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
      kernel_input_feature_dimension = 2 : i64,
      kernel_output_feature_dimension = 3 : i64,
      kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
      output_batch_dimension = 0 : i64,
      output_feature_dimension = 3 : i64,
      output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
    },
    feature_group_count = 1 : i64,
    padding = dense<0> : tensor<2x2xi64>,
    rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
    window_strides = dense<1> : tensor<2xi64>
  } : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
  return %0 : tensor<1x2x4x3xf32>
}
// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>
@@ -2210,6 +2210,70 @@ static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
      "UnidirectionalSequenceLSTMOp expected to have two stateful operands");
}

LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes(
    MLIRContext *, Optional<Location>, ValueRange operands, DictionaryAttr attr,
    RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
  Value input = operands[0];
  auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();

  Value output_state = operands[18];
  auto output_state_type =
      output_state.getType().dyn_cast_or_null<RankedTensorType>();

  if (input_type && input_type.hasRank() && input_type.getRank() != 3) {
    return failure();
  }

  if (output_state_type && output_state_type.hasRank() &&
      output_state_type.getRank() != 2) {
    return failure();
  }

  if (!input_type || !input_type.hasRank() || !output_state_type ||
      !output_state_type.hasRank()) {
    // We cannot infer the output shape since we don't know the input shape or
    // the output state shape. We will set the output shape as unranked.
    Type result_type;
    result_type = UnrankedTensorType::get(
        input.getType().cast<ShapedType>().getElementType());
    inferredReturnTypes.assign({result_type});
    return success();
  }

  // Default to non-time_major.
  bool time_majored = attr.getNamed("time_major").hasValue()
                          ? attr.getNamed("time_major")
                                .getValue()
                                .second.cast<BoolAttr>()
                                .getValue()
                          : false;

  int batch =
      time_majored ? input_type.getDimSize(1) : input_type.getDimSize(0);
  int time = time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1);
  int n_output = output_state_type.getDimSize(1);

  // Build the output shape.
  SmallVector<int64_t, 3> output_shape;
  if (time_majored) {
    output_shape = {time, batch, n_output};
  } else {
    output_shape = {batch, time, n_output};
  }
  auto result_type =
      mlir::RankedTensorType::get(output_shape, input_type.getElementType());

  inferredReturnTypes.assign({result_type});
  return success();
}

bool UnidirectionalSequenceLSTMOp::isCompatibleReturnTypes(ArrayRef<Type> lhs,
                                                           ArrayRef<Type> rhs) {
  if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
  if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
  return true;
}

//===----------------------------------------------------------------------===//
// BidirectionalSequenceLSTMOp
//===----------------------------------------------------------------------===//
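The shape arithmetic in inferReturnTypes above reduces to a small formula. Here is a hypothetical standalone sketch (not part of this commit; the real code reads the dims from RankedTensorType), using the shapes from the shape-inference test later in this diff: with time_major = false, input 600x10x20 and output_state 600x40 give batch=600, time=10, n_output=40, so the inferred output is 600x10x40.

#include <array>
#include <cstdint>

// Illustrative helper mirroring the output-shape computation above.
std::array<std::int64_t, 3> LstmOutputShape(
    const std::array<std::int64_t, 3>& input,        // rank-3 input shape
    const std::array<std::int64_t, 2>& output_state, // rank-2 state shape
    bool time_major) {
  const std::int64_t batch = time_major ? input[1] : input[0];
  const std::int64_t time = time_major ? input[0] : input[1];
  const std::int64_t n_output = output_state[1];
  return time_major
             ? std::array<std::int64_t, 3>{time, batch, n_output}
             : std::array<std::int64_t, 3>{batch, time, n_output};
}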
@@ -3959,7 +3959,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
     TFL_OperandHasRank<15, 1>,  // output_gate_bias
     TFL_OperandIsNoneOrHasRank<16, 2>,  // projection_weights
     TFL_OperandIsNoneOrHasRank<17, 1>,  // projection_bias
     TFL_StatefulOp]> {
     TFL_StatefulOp,
     DeclareOpInterfaceMethods<InferTypeOpInterface>
  ]> {
  let summary = "Unidirectional sequence lstm operator";

  let description = [{
@@ -4043,6 +4045,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
  let extraClassDeclaration = [{
    // StatefulOpInterface:
    std::vector<int> GetStatefulOperands() { return {18, 19}; }

    // Compatible return types check
    static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r);
  }];
}
@@ -459,3 +459,40 @@ func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lea
// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
}

// -----

func @whileLoopWithDynamicTensorList(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<*xf32> {
  %cst = constant dense<3> : tensor<1xi32>
  %cst_0 = constant dense<0> : tensor<i32>
  %cst_1 = constant dense<-1> : tensor<i32>
  %0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?xf32>>>
  %1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond, is_stateless = false} : (tensor<i32>, tensor<!tf.variant<tensor<?x?xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
  %2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
  return %2 : tensor<*xf32>

// verify tensorlist ops pass through.
// CHECK-LABEL: func @whileLoopWithDynamicTensorList
// CHECK: "tf.TensorListReserve"
// CHECK: "tf.While"
// CHECK-SAME: (tensor<i32>, tensor<!tf.variant<tensor<?x?xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
// CHECK: "tf.TensorListStack"
}

func @tensorlistWhileBody(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>) {
  %0 = "tf.TensorListLength"(%arg1) : (tensor<!tf.variant>) -> tensor<i32>
  %1 = "tf.Identity"(%arg1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
  return %0, %1 : tensor<i32>, tensor<!tf.variant>

// verify `body` function's signature stays unchanged.
// CHECK: func @tensorlistWhileBody(%[[ARG0:.*]]: tensor<i32>, %[[ARG:.*]]: tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>)
}

func @tensorlistWhileCond(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> tensor<i1> {
  %cst = constant dense<2> : tensor<i32>
  %0 = "tf.Less"(%arg0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i1>
  return %0 : tensor<i1>

// verify `cond` function's signature stays unchanged.
// CHECK: func @tensorlistWhileCond(%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<!tf.variant>) -> tensor<i1>
}
@@ -74,3 +74,14 @@ func @testConv2dShapeInvalidRanks(%arg0: tensor<1x112x80xf32>, %arg1: tensor<128
  }
}

// -----

module attributes {tf.versions = {producer = 888 : i32}} {
  // CHECK-LABEL: testUnidirectionalSequenceLstmShapeInference
  func @testUnidirectionalSequenceLstmShapeInference(%arg0: tensor<600 x 10 x 20 x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<40 x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<600 x 40 x f32>, %arg19: tensor<600 x 40 x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x ? x ? x f32> {
    // CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600x10x20xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<40xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<600x40xf32>, tensor<600x40xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<600x10x40xf32>
    %0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600 x 10 x 20 x f32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<40xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<600x40xf32>, tensor<600x40xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<? x ? x ? xf32>
    return %0 : tensor<? x ? x ? x f32>
  }
}
@@ -860,7 +860,8 @@ llvm::SmallSet<int, 4> GetTensorListArgumentsFromWhileOp(TF::WhileOp op) {

// Changes the function type of `cond_func` and `body_func` for the given While
// op.
LogicalResult UpdateFunctionTypes(TF::WhileOp op,
LogicalResult UpdateFunctionTypes(ConversionPatternRewriter &rewriter,
                                  TF::WhileOp op,
                                  llvm::SmallSet<int, 4> tensor_list_args) {
  int func_index = 0;
  for (FuncOp func : {op.cond_function(), op.body_function()}) {
@@ -906,13 +907,18 @@ LogicalResult UpdateFunctionTypes(TF::WhileOp op,
    // Change `func`'s argument type to `unranked_argument_types`. If its
    // return types contain a `DT_VARIANT`, change it to the unranked type
    // derived from the corresponding argument.
    func.setType(FunctionType::get(op.getContext(), updated_argument_types,
                                   updated_result_types));

    // Change the argument type for the first block.
    llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
      arg.setType(updated_argument_types[arg.getArgNumber()]);
    rewriter.updateRootInPlace(func, [&] {
      func.setType(FunctionType::get(op.getContext(), updated_argument_types,
                                     updated_result_types));
    });
    Region &entry = func.getRegion();
    TypeConverter::SignatureConversion signature_conversion(
        entry.getNumArguments());
    for (auto arg : entry.getArguments()) {
      signature_conversion.addInputs(
          arg.getArgNumber(), updated_argument_types[arg.getArgNumber()]);
    }
    rewriter.applySignatureConversion(&entry, signature_conversion);
  }
  return success();
}
@@ -944,7 +950,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
    auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
                                                  operands, op.getAttrs());
    converted.removeAttr("T");
    (void)UpdateFunctionTypes(converted, tensor_list_args);
    (void)UpdateFunctionTypes(rewriter, converted, tensor_list_args);

    rewriter.replaceOp(op, converted.getResults());
    return success();
@@ -162,18 +162,21 @@ std::string ExperimentalConvertSavedModelToMlir(
}

std::string ExperimentalConvertSavedModelV1ToMlirLite(
    const std::string &saved_model_path, const std::string &tags,
    bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
    const std::string &saved_model_path, const std::string &exported_names_str,
    const std::string &tags, bool upgrade_legacy, bool show_debug_info,
    TF_Status *status) {
  std::unordered_set<string> tag_set =
      absl::StrSplit(tags, ',', absl::SkipEmpty());

  std::vector<string> exported_names =
      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
  mlir::MLIRContext context;

  tensorflow::MLIRImportOptions import_options;
  import_options.upgrade_legacy = upgrade_legacy;
  auto module_or = SavedModelSignatureDefsToMlirImportLite(
      saved_model_path, tag_set, /*exported_names=*/{}, &context,
      import_options);
      saved_model_path, tag_set, absl::Span<std::string>(exported_names),
      &context, import_options);
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
@@ -183,9 +186,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
}

std::string ExperimentalConvertSavedModelV1ToMlir(
    const std::string &saved_model_path, const std::string &tags,
    bool lift_variables, bool upgrade_legacy, bool show_debug_info,
    TF_Status *status) {
    const std::string &saved_model_path, const std::string &exported_names_str,
    const std::string &tags, bool lift_variables, bool upgrade_legacy,
    bool show_debug_info, TF_Status *status) {
  // Load the saved model into a SavedModelBundle.

  std::unordered_set<string> tag_set =
@@ -200,12 +203,14 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
  }

  // Convert the SavedModelBundle to an MLIR module.

  std::vector<string> exported_names =
      absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
  mlir::MLIRContext context;
  tensorflow::MLIRImportOptions import_options;
  import_options.upgrade_legacy = upgrade_legacy;
  auto module_or =
      ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options);
      ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names),
                                &context, import_options);
  if (!module_or.status().ok()) {
    Set_TF_Status_from_Status(status, module_or.status());
    return "// error";
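Both functions now parse the new exported_names_str parameter the same way tags are parsed. A hedged illustration of that behavior (hypothetical wrapper, not part of this commit): absl::SkipEmpty() makes the empty string yield an empty vector, which the importer treats as "export everything", matching the Python test below that passes exported_names = ''.

#include <string>
#include <vector>
#include "absl/strings/str_split.h"

// Illustrative wrapper around the absl::StrSplit calls above.
// "" -> {}; "a,b" -> {"a", "b"}; "a,,b" -> {"a", "b"}.
std::vector<std::string> ParseExportedNames(const std::string &s) {
  return absl::StrSplit(s, ',', absl::SkipEmpty());
}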
@@ -69,8 +69,9 @@ std::string ExperimentalConvertSavedModelToMlir(
// Returns:
//  A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelV1ToMlirLite(
    const std::string &saved_model_path, const std::string &tags,
    bool upgrade_legacy, bool show_debug_info, TF_Status *status);
    const std::string &saved_model_path, const std::string &exported_names_str,
    const std::string &tags, bool upgrade_legacy, bool show_debug_info,
    TF_Status *status);

// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
//
@@ -83,9 +84,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
// Returns:
//  A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelV1ToMlir(
    const std::string &saved_model_path, const std::string &tags,
    bool lift_variables, bool upgrade_legacy, bool show_debug_info,
    TF_Status *status);
    const std::string &saved_model_path, const std::string &exported_names_str,
    const std::string &tags, bool lift_variables, bool upgrade_legacy,
    bool show_debug_info, TF_Status *status);

std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
                                        const std::string &pass_pipeline,
@@ -102,11 +102,13 @@ def do_test(create_signature,
    logging.info('Saved model to: %s', save_model_path)
    # TODO(b/153507667): Set the following boolean flag once the hoisting
    # variables logic from SavedModel importer is removed.
    exported_names = ''
    lift_variables = False
    upgrade_legacy = True
    if use_lite:
      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
          save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
          save_model_path, exported_names,
          ','.join([tf.saved_model.tag_constants.SERVING]),
          upgrade_legacy, show_debug_info)
      # We don't strictly need this, but it serves as a handy sanity check
      # for that API, which is otherwise a bit annoying to test.
@@ -116,7 +118,8 @@ def do_test(create_signature,
                                           show_debug_info)
    else:
      mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
          save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
          save_model_path, exported_names,
          ','.join([tf.saved_model.tag_constants.SERVING]),
          lift_variables, upgrade_legacy, show_debug_info)

    if canonicalize:
@@ -113,6 +113,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
  // Run another shape inference pass because resource decomposition might have
  // created new partial types.
  pm.addPass(TF::CreateTFShapeInferencePass());
  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addPass(CreateTPUClusterCleanupAttributesPass());
@@ -78,7 +78,11 @@ void AddSupportedOpsUsingFolding(MLIRContext* context,
      OperationName(TF::ConcatOffsetOp::getOperationName(), context),
      OperationName(TF::EmptyOp::getOperationName(), context),
      OperationName(TF::ListDiffOp::getOperationName(), context),
      OperationName(TF::RankOp::getOperationName(), context),
      OperationName(TF::RangeOp::getOperationName(), context),
      OperationName(TF::ShapeOp::getOperationName(), context),
      OperationName(TF::ShapeNOp::getOperationName(), context),
      OperationName(TF::SizeOp::getOperationName(), context),
  };

  supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());
@@ -279,7 +279,10 @@ void CreateConvertMlirToXlaHloPipeline(
    mlir::OpPassManager& pm, llvm::StringRef device_type,
    llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
        custom_legalization_passes) {
  // Note that the region-based control-flow produced here still contains
  // function call ops which get inlined by the subsequent inliner pass.
  pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
  pm.addPass(mlir::createInlinerPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateDropWhileShapeInvariantPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
  // The SCCP pass performs constant propagation across the IR, which, for
@@ -317,6 +320,7 @@ void CreateConvertMlirToXlaHloPipeline(
  // inside PromoteResourcesToArgs.
  pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());

  pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
  pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
      /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
      /*tf2xla_fallback_device_type=*/device_type));
@@ -372,7 +376,7 @@ Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,

  if (failed(tf2xla.run(module_op))) {
    return error_handler.Combine(
        errors::InvalidArgument("TF to XLA legalization failed"));
        errors::InvalidArgument("TF to XLA legalization failed: "));
  }

  if (VLOG_IS_ON(1))
@@ -29,7 +29,7 @@ package_group(
)

gentbl(
    name = "xla_legalize_tf_inc_gen",
    name = "legalize_tf_patterns_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        ("-gen-rewriters", "transforms/generated_legalize_tf.inc"),
@@ -48,6 +48,22 @@ gentbl(
    ],
)

gentbl(
    name = "xla_legalize_tf_passes_inc_gen",
    compatible_with = get_compatible_with_cloud(),
    tbl_outs = [
        ("-gen-pass-decls -name XLA", "transforms/xla_legalize_tf_passes.h.inc"),
    ],
    tblgen = "@llvm-project//mlir:mlir-tblgen",
    td_file = "transforms/xla_legalize_tf_passes.td",
    td_relative_includes = [
        "../hlo/include",
    ],
    td_srcs = [
        "@llvm-project//mlir:PassBaseTdFiles",
    ],
)

gentbl(
    name = "xla_passes_inc_gen",
    compatible_with = get_compatible_with_cloud(),
@@ -72,8 +88,8 @@ gentbl(
cc_library(
    name = "xla_passes",
    srcs = [
        "transforms/passes_detail.h",
        "transforms/prepare_for_export.cc",
        "transforms/xla_passes_detail.h",
    ],
    hdrs = [
        "transforms/passes.h",
@@ -96,19 +112,25 @@ cc_library(
        "transforms/legalize_tf.cc",
        "transforms/legalize_tf_communication.cc",
        "transforms/legalize_tf_control_flow.cc",
        "transforms/legalize_tf_types.cc",
        "transforms/xla_legalize_tf_passes_detail.h",
    ],
    hdrs = [
        "transforms/passes.h",
    ],
    deps = [
        ":attribute_importer",
        ":legalize_tf_patterns_inc_gen",
        ":type_to_shape",
        ":xla_legalize_tf_passes_inc_gen",
        ":xla_legalize_tf_with_tf2xla",
        ":xla_passes",
        "//tensorflow/compiler/mlir/hlo",
        "//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
        "//tensorflow/compiler/mlir/hlo:convert_op_folder",
        "//tensorflow/compiler/mlir/tensorflow",
        "//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
        "//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
        "//tensorflow/compiler/xla:shape_util",
        "//tensorflow/compiler/xla:xla_data_proto_cc",
        "//tensorflow/compiler/xla/client:padding",
54
tensorflow/compiler/mlir/xla/tests/legalize-tf-types.mlir
Normal file
@@ -0,0 +1,54 @@
// RUN: tf-opt -xla-legalize-tf-types %s | FileCheck %s

func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
  // CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
  // CHECK-NEXT: %[[X:.*]] = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
  %0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8>
  return %0: tensor<1x!tf.qint8>
}

func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1x!tf.qint8>, %arg2: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
  // CHECK: func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8>
  // CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ( {
  // CHECK-NEXT: "tf.Yield"(%arg1) : (tensor<1xi8>) -> ()
  // CHECK-NEXT: }, {
  // CHECK-NEXT: "tf.Yield"(%arg2) : (tensor<1xi8>) -> ()
  // CHECK-NEXT: }) {is_stateless = false} : (tensor<i1>) -> tensor<1xi8>
  // CHECK-NEXT: return %0 : tensor<1xi8>
  %0 = "tf.IfRegion"(%arg0) ( {
    "tf.Yield"(%arg1) : (tensor<1x!tf.qint8>) -> ()
  }, {
    "tf.Yield"(%arg2) : (tensor<1x!tf.qint8>) -> ()
  }) {is_stateless = false} : (tensor<i1>) -> tensor<1x!tf.qint8>
  return %0 : tensor<1x!tf.qint8>
}

func @id_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
  // CHECK: func @id_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
  // CHECK-NEXT: return %arg0 : tensor<1xi8>
  return %arg0: tensor<1x!tf.qint8>
}

func @id_qint16(%arg0: tensor<1x!tf.qint16>) -> tensor<1x!tf.qint16> {
  // CHECK: func @id_qint16(%arg0: tensor<1xi16>) -> tensor<1xi16> {
  // CHECK-NEXT: return %arg0 : tensor<1xi16>
  return %arg0: tensor<1x!tf.qint16>
}

func @id_qint32(%arg0: tensor<1x!tf.qint32>) -> tensor<1x!tf.qint32> {
  // CHECK: func @id_qint32(%arg0: tensor<1xi32>) -> tensor<1xi32> {
  // CHECK-NEXT: return %arg0 : tensor<1xi32>
  return %arg0: tensor<1x!tf.qint32>
}

func @id_quint8(%arg0: tensor<1x!tf.quint8>) -> tensor<1x!tf.quint8> {
  // CHECK: func @id_quint8(%arg0: tensor<1xui8>) -> tensor<1xui8> {
  // CHECK-NEXT: return %arg0 : tensor<1xui8>
  return %arg0: tensor<1x!tf.quint8>
}

func @id_quint16(%arg0: tensor<1x!tf.quint16>) -> tensor<1x!tf.quint16> {
  // CHECK: func @id_quint16(%arg0: tensor<1xui16>) -> tensor<1xui16> {
  // CHECK-NEXT: return %arg0 : tensor<1xui16>
  return %arg0: tensor<1x!tf.quint16>
}
@@ -496,7 +496,7 @@ def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featur

// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
def : Pat<(TF_ReluOp AnyRankedTensor:$input),
def : Pat<(TF_ReluOp AnyTensor:$input),
          (HLOClient_BroadcastMaxOp
              (HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input,
              (BinBroadcastDimensions $zero, $input)),
178
tensorflow/compiler/mlir/xla/transforms/legalize_tf_types.cc
Normal file
@@ -0,0 +1,178 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// The TF dialect uses some TF types that are illegal in the MHLO dialect and
// some generic types that are legal in MHLO. This pass legalizes TF types into
// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
// Rewrites here should run before TF to MHLO op legalizations are run.
// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
// rather than its own pass.

#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinOps.h"  // from @llvm-project
#include "mlir/IR/BuiltinTypes.h"  // from @llvm-project
#include "mlir/IR/PatternMatch.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "mlir/Support/LLVM.h"  // from @llvm-project
#include "mlir/Support/LogicalResult.h"  // from @llvm-project
#include "mlir/Transforms/DialectConversion.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"

#define DEBUG_TYPE "xla-legalize-tf-types"

namespace mlir {
namespace mhlo {
namespace {

bool IsIllegalElementType(Type type) {
  return type
      .isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
           mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
}

Type ToLegalElementType(Type type) {
  return TypeSwitch<Type, Type>(type)
      .Case<mlir::TF::Qint8Type>([&type](Type) {
        return mlir::IntegerType::get(type.getContext(), 8);
      })
      .Case<mlir::TF::Qint16Type>([&type](Type) {
        return mlir::IntegerType::get(type.getContext(), 16);
      })
      .Case<mlir::TF::Qint32Type>([&type](Type) {
        return mlir::IntegerType::get(type.getContext(), 32);
      })
      .Case<mlir::TF::Quint8Type>([&type](Type) {
        return mlir::IntegerType::get(
            type.getContext(), 8,
            mlir::IntegerType::SignednessSemantics::Unsigned);
      })
      .Case<mlir::TF::Quint16Type>([&type](Type) {
        return mlir::IntegerType::get(
            type.getContext(), 16,
            mlir::IntegerType::SignednessSemantics::Unsigned);
      })
      .Default([&type](Type) { return type; });
}

// TODO(b/180234863): What's below this line is generic so convert it to a
// utility.

bool IsIllegalType(Type type) {
  return IsIllegalElementType(getElementTypeOrSelf(type));
}

Type ToLegalType(Type type) {
  if (IsIllegalElementType(type)) return ToLegalElementType(type);
  if (auto shaped = type.dyn_cast<ShapedType>()) {
    Type elem = shaped.getElementType();
    if (IsIllegalType(elem)) return shaped.clone(ToLegalType(elem));
  }
  return type;
}

class TfTypeConverter : public TypeConverter {
 public:
  TfTypeConverter() {
    addConversion([](Type type) -> Type {
      return IsIllegalType(type) ? ToLegalType(type) : type;
    });
  }
};

// An op is illegal iff it contains an illegal type.
class TfTypeConversionTarget : public ConversionTarget {
 public:
  explicit TfTypeConversionTarget(MLIRContext &ctx, TfTypeConverter &converter)
      : ConversionTarget(ctx), converter_(converter) {
    markUnknownOpDynamicallyLegal();
  }

 protected:
  bool isDynamicallyLegal(Operation *op) const override {
    // The FuncOp type can contain types that the op's operand and result types
    // do not contain.
    if (auto func = dyn_cast<FuncOp>(op)) {
      if (!converter_.isSignatureLegal(func.getType())) return false;
    }
    return converter_.isLegal(op);
  }

 private:
  TfTypeConverter &converter_;
};

class TfTypePattern : public ConversionPattern {
 public:
  TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
      : ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}

  // The dialect conversion framework will call this matchAndRewrite on each
  // Operation in the IR tree. Each call needs to update the Operation's
  // results and child regions.
  LogicalResult matchAndRewrite(
      Operation *op, ArrayRef<Value> operands,
      ConversionPatternRewriter &rewriter) const override {
    // Update the results.
    llvm::SmallVector<Type, 4> new_results;
    if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
                                                new_results)))
      return failure();

    // Update the regions. The dialect conversion framework wants new regions
    // to be created and updated, rather than updating the old op. Thus we use
    // an OperationState so we can add regions to the new op.
    OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
                         new_results, op->getAttrs(), op->getSuccessors());
    for (Region &region : op->getRegions()) {
      Region &new_region = *state.addRegion();
      rewriter.inlineRegionBefore(region, new_region, new_region.begin());
      if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
        return failure();
    }
    rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());

    return success();
  }
};

struct LegalizeTfTypesPass
    : public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
  void runOnOperation() override;
};

void LegalizeTfTypesPass::runOnOperation() {
  TfTypeConverter converter;
  OwningRewritePatternList patterns;
  patterns.insert<TfTypePattern>(&getContext(), converter);
  populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
  TfTypeConversionTarget target(getContext(), converter);
  if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
    return signalPassFailure();
}

static PassRegistration<LegalizeTfTypesPass> registration(
    "xla-legalize-tf-types",
    "Replace TensorFlow types with types that are legal in the MHLO dialect");

}  // namespace

std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
  return std::make_unique<LegalizeTfTypesPass>();
}

}  // namespace mhlo
}  // namespace mlir
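For quick reference, the mapping implemented by ToLegalElementType above can be summarized as width/signedness pairs. A hypothetical standalone sketch (illustrative only; the real pass switches on MLIR type classes, not an enum):

#include <utility>

// Illustrative table of the quantized-to-standard integer mapping above.
enum class TfQuantType { kQint8, kQint16, kQint32, kQuint8, kQuint16 };

std::pair<int, bool> LegalIntWidth(TfQuantType t) {
  switch (t) {
    case TfQuantType::kQint8:   return {8, false};   // tf.qint8   -> i8
    case TfQuantType::kQint16:  return {16, false};  // tf.qint16  -> i16
    case TfQuantType::kQint32:  return {32, false};  // tf.qint32  -> i32
    case TfQuantType::kQuint8:  return {8, true};    // tf.quint8  -> ui8
    case TfQuantType::kQuint16: return {16, true};   // tf.quint16 -> ui16
  }
  return {0, false};  // unreachable
}

The identity tests in legalize-tf-types.mlir earlier in this diff exercise exactly these five cases.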
@@ -49,6 +49,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
    llvm::StringRef device_type);

/// Replaces types that do not exist in MHLO with equivalent types that do
/// exist.
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTfTypesPass();

/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
                                          OwningRewritePatternList& patterns);
@@ -26,7 +26,7 @@ limitations under the License.
#include "mlir/IR/Types.h"  // from @llvm-project
#include "mlir/Pass/Pass.h"  // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"

#define DEBUG_TYPE "xla-prepare-for-export"

@@ -0,0 +1,36 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Declare passes used in xla_legalize_tf.

include "mlir/Pass/PassBase.td"

def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> {
  let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect";

  let description = [{
    The TF dialect uses some TF types that are illegal in the MHLO dialect and
    some generic types that are legal in MHLO. This pass legalizes TF types
    into types that are legal in MHLO. Rewrites here should run before TF to
    MHLO op legalizations are run.

    Specifically, this pass replaces each quantized integer type with the
    corresponding ordinary type. For example, `TF::Qint8Type` is replaced with
    `i8` everywhere it occurs. Types that are replaced are `TF::Qint8Type`,
    `TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`.
  }];

  let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()";
}
@@ -0,0 +1,31 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TF_PASSES_DETAIL_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TF_PASSES_DETAIL_H_

#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace mhlo {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.h.inc"

}  // namespace mhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TF_PASSES_DETAIL_H_
31
tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h
Normal file
@@ -0,0 +1,31 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_DETAIL_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_DETAIL_H_

#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"

namespace mlir {
namespace mhlo {

#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h.inc"

}  // namespace mhlo
}  // namespace mlir

#endif  // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_DETAIL_H_
@@ -467,6 +467,80 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
  return Status::OK();
}

// Prepares a dynamic shape tensor for broadcast by adding leading 1 dimensions.
Status DynamicBroadcast(nvinfer1::ITensor* operand, OpConverterParams* params,
                        nvinfer1::ITensor** output, int broadcasted_nbDims) {
  int operand_nbDims = operand->getDimensions().nbDims;
  if (broadcasted_nbDims > operand_nbDims) {
    if (params->validation_only) return Status::OK();
    int n_extra_dims = broadcasted_nbDims - operand_nbDims;
    VLOG(2) << "Dynamic broadcast adding " << n_extra_dims << " leading 1s";
    TF_RETURN_IF_ERROR(params->converter->DynamicReshape(
        operand, {std::make_pair(0, operand_nbDims)}, params, output,
        {n_extra_dims}));
  } else {
    *output = operand;
  }
  return Status::OK();
}

Status BroadcastWeights(std::unique_ptr<TRT_TensorOrWeights>& p,
                        nvinfer1::Dims broadcasted_dims) {
  if (!p->is_weights()) return errors::Internal("Weight input expected");
  if (p->GetTrtDims().nbDims != broadcasted_dims.nbDims) {
    TRT_ShapedWeights weights(p->weights());
    TF_RETURN_IF_ERROR(weights.SetShape(broadcasted_dims));
    p = std::make_unique<TRT_TensorOrWeights>(weights);
  }
  return Status::OK();
}

Status ApplyBroadcast(std::unique_ptr<TRT_TensorOrWeights>& operand,
                      nvinfer1::Dims broadcasted_dims,
                      OpConverterParams* params) {
  if (operand->is_weights()) {
    TF_RETURN_IF_ERROR(BroadcastWeights(operand, broadcasted_dims));
  } else {
    nvinfer1::ITensor* tensor = nullptr;
    auto is_static_shuffle_compatible = [](nvinfer1::Dims dims) {
      return std::count(dims.d, dims.d + dims.nbDims, -1) <= 1;
    };
    if (is_static_shuffle_compatible(broadcasted_dims)) {
      TF_RETURN_IF_ERROR(PrepareTensorForShape(
          params->converter, *operand, broadcasted_dims,
          params->validation_only, &tensor, params->node_def));
    } else {
      TF_RETURN_IF_ERROR(DynamicBroadcast(operand->tensor(), params, &tensor,
                                          broadcasted_dims.nbDims));
    }
    operand = std::make_unique<TRT_TensorOrWeights>(tensor);
  }
  return Status::OK();
}

// Inserts leading 1 dimensions so that both operands have the same rank.
// Note: In implicit batch mode, weights' shape can include an explicit 1 batch
// dimension. The broadcasted shape might lose this leading batch dim, because
// the broadcasted shape does not include the implicit batch dim.
// TODO(tfeher): Other code blocks that use GetTrtBroadcastShape need to be
// fixed to use this routine to handle dynamic inputs. Eventually,
// GetTrtBroadcastShape should only be used by this routine.
Status BroadcastTensors(std::unique_ptr<TRT_TensorOrWeights>& operand_l,
                        std::unique_ptr<TRT_TensorOrWeights>& operand_r,
                        bool check_feasibility, OpConverterParams* params) {
  nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
  TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
      *operand_l, *operand_r, check_feasibility, params->use_implicit_batch,
      &broadcasted_dims_l, &broadcasted_dims_r));

  if (params->validation_only) return Status::OK();

  TF_RETURN_IF_ERROR(ApplyBroadcast(operand_l, broadcasted_dims_l, params));
  TF_RETURN_IF_ERROR(ApplyBroadcast(operand_r, broadcasted_dims_r, params));

  return Status::OK();
}
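The rank-alignment rule used by DynamicBroadcast above is plain NumPy-style broadcasting applied to the rank: pad the lower-rank shape with leading 1s. A hypothetical standalone sketch (not part of this commit; the real code emits a TensorRT reshape and assumes the target rank is at least the operand rank):

#include <vector>

// Illustrative helper: prepend 1s until dims has target_rank entries.
// Assumes target_rank >= dims.size().
std::vector<int> PrependOnes(const std::vector<int>& dims, int target_rank) {
  std::vector<int> result(target_rank - static_cast<int>(dims.size()), 1);
  result.insert(result.end(), dims.begin(), dims.end());
  return result;
}

For example, PrependOnes({5, 3}, 4) returns {1, 1, 5, 3}, the shape a rank-2 operand would be reshaped to before a rank-4 elementwise op.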
nvinfer1::ITensor* Converter::CreateConstantLayer(
    const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) {
  nvinfer1::Weights trt_weights = weights.GetTrtWeights();
@@ -731,6 +805,17 @@ Status TRT_ShapedWeights::SetValues(T value) {
  return Status::OK();
}

Status TRT_ShapedWeights::SetShape(nvinfer1::Dims dims) {
  if (this->count() != TrtWeightDimsNumElements(dims)) {
    VLOG(2) << "Changing shape from "
            << tensorflow::tensorrt::DebugString(shape_) << ", to "
            << tensorflow::tensorrt::DebugString(dims);
    return errors::Internal("SetShape would change number of elements");
  }
  shape_ = dims;
  return Status::OK();
}

size_t TRT_ShapedWeights::size_bytes() const {
  size_t data_type_size = -1;
  switch (type_) {
@@ -1707,14 +1792,18 @@ Status PrepareTensorForShape(Converter* converter,
                             const NodeDef& node_def,
                             absl::optional<int> op_instance) {
  const nvinfer1::Dims input_dims = input.GetTrtDims();
  // If one of input_dims and dims doesn't have static shape, it means some of
  // the dims are unknown or need to be inferred. And we don't do further checks
  // but rely on the caller to not make mistakes.
  // Otherwise we do simple check to make sure the total sizes are the same.
  // The input shape may have -1s for dynamic shape. The target shape may have
  // 0s representing copy over the corresponding input dimensions. It may also
  // have at most one -1 representing a dimension value that needs to be
  // inferred. If none of those special values are present, we verify that the
  // total sizes of the input and output shape are the same.
  // TODO(tfeher): Verify that the total sizes of the input and output shape are
  // the same in the presence of 0s but no -1 in the target shape.
  // If an input is a weight, it is going to become a tensor via
  // CreateConstantLayer. So we can treat it as a tensor for
  // AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
  if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
  if (Prod(dims) > 0 &&
      AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
    return errors::InvalidArgument(
        "Incompatible shapes: ", DebugString(input_dims), " vs. ",
        DebugString(dims));
|
||||
int slice_instance = i * max_num_slices + op_instance_value;
|
||||
// maybe_add_a_dimension(i);
|
||||
if (i < size_for_added_dims.size() && size_for_added_dims[i] >= 0) {
|
||||
nvinfer1::Dims dims{1, {1}};
|
||||
if (size_for_added_dims[i] > 0) {
|
||||
dims.d[0] = size_for_added_dims[i];
|
||||
}
|
||||
TF_RETURN_IF_ERROR(
|
||||
CreateScalarConstant(params, size_for_added_dims[i], &tensor));
|
||||
CreateScalarConstant(params, std::min(size_for_added_dims[i], 1),
|
||||
&tensor, nvinfer1::DataType::kINT32, dims));
|
||||
concat_inputs.push_back(tensor);
|
||||
}
|
||||
if (i < slices.size()) {
|
||||
@@ -5422,30 +5516,86 @@ Status ConvertGather(OpConverterParams* params) {
  return Status::OK();
}

// Converts the input matrix multiplication node to a fully connected (FC)
// layer if possible, as the FC layer has more tactics and INT implementations.
// Sets *converted to true if the node is converted. Otherwise, *converted ==
// false together with Status::OK indicates the node can't be converted to an
// FC layer. An error Status indicates internal problems during conversion.
Status ConvertFullyConnectedHelper(OpConverterParams* params,
                                   nvinfer1::ITensor* tensor_a,
                                   TRT_ShapedWeights weights_b,
                                   bool transpose_b) {
  // Reshape input to 3D - this will be a no-op unless using int8 precision.
  auto input_dim = tensor_a->getDimensions();
  while (input_dim.nbDims < 3) {
    input_dim.d[input_dim.nbDims++] = 1;
                                   TRT_TensorOrWeights input_a,
                                   TRT_TensorOrWeights input_b,
                                   bool transpose_a, bool transpose_b,
                                   bool* converted) {
  *converted = false;

  if (!(!transpose_a && input_a.is_tensor() && input_b.is_weights())) {
    VLOG(2) << "Not FC compatible, A must be non transposed tensor, and B "
               "must be constant.";
    return Status::OK();
  }
  const auto& node_def = params->node_def;

  if (!params->use_implicit_batch && input_b.GetTrtDims().nbDims > 2 &&
      input_b.GetTrtDims().d[0] != 1) {
    // Implicit broadcasting, if needed, has already transformed the inputs so
    // that the two operands have the same rank here.
    // If the inputs have rank >= 3, then d[0] is the explicit batch dimension.
    // The weight (input_b) must have batch size 1 in implicit batch mode.
    VLOG(2) << "Not FC compatible, if B has an explicit batch dimension, then "
               "it must be 1.";
    return Status::OK();
  }

  nvinfer1::Dims input_dim = input_a.GetTrtDims();
  if (input_dim.d[input_dim.nbDims - 1] == -1) {
    VLOG(2) << "Not FC compatible, last dim of A must be static.";
    return Status::OK();
  }

  if (input_dim.nbDims + 2 > nvinfer1::Dims::MAX_DIMS) {
    VLOG(2) << "Not FC compatible, cannot expand A's shape.";
    return Status::OK();
  }

  // Add two trailing 1's because FC layer combines the last three dims.
  nvinfer1::ITensor* tensor_a = nullptr;
  nvinfer1::Dims reshape_dim{input_dim.nbDims + 2, {}};
  // The empty braces initialize the elements of reshape_dim.d to 0. A value 0
  // in reshape_dim.d[i] will preserve the i-th dimension value from the shape
  // of input_a.
  reshape_dim.d[input_dim.nbDims] = 1;
  reshape_dim.d[input_dim.nbDims + 1] = 1;
  const NodeDef& node_def = params->node_def;
  TF_RETURN_IF_ERROR(PrepareTensorForShape(
      params->converter, TRT_TensorOrWeights(tensor_a), input_dim,
      params->converter, input_a, reshape_dim,
      /*validation_only=*/false, &tensor_a, node_def, /*op_instance=*/0));

  VLOG(2) << "New shape of A " << DebugString(tensor_a->getDimensions());

  TRT_ShapedWeights weights_b = input_b.weights();
  TRT_ShapedWeights weights_2D(weights_b);
  if (weights_b.shape_.nbDims > 2) {
    // Combine first nbDims-1 dims into a single dim, e.g. for a 4D tensor we
    // transform [N, H, W, C] -> [N*H*W, C].
    int k = weights_b.shape_.d[weights_b.shape_.nbDims - 1];
    nvinfer1::Dims dims{2, {static_cast<int>(weights_b.count() / k), k}};
    TF_RETURN_IF_ERROR(weights_2D.SetShape(dims));
  }
|
||||
// FC layer will transpose weights, so we need to pre-transpose.
|
||||
TRT_ShapedWeights weights(weights_b.TrtDType());
|
||||
TRT_ShapedWeights weights(weights_2D.TrtDType());
|
||||
if (!transpose_b) {
|
||||
weights = params->weight_store->GetTempWeights(weights_b);
|
||||
ReorderCKtoKC(weights_b, &weights);
|
||||
weights = params->weight_store->GetTempWeights(weights_2D);
|
||||
ReorderCKtoKC(weights_2D, &weights);
|
||||
} else {
|
||||
weights = weights_b;
|
||||
weights = weights_2D;
|
||||
}
|
||||
TRT_ShapedWeights biases(weights.TrtDType());
|
||||
const int noutput = weights.shape_.d[0];
|
||||
int k = weights.shape_.d[weights.shape_.nbDims - 1];
|
||||
const int noutput = weights.count() / k;
|
||||
VLOG(2) << "Using fully connected layer with k=" << k
|
||||
<< ", n_output=" << noutput
|
||||
<< " weights shape: " << DebugString(weights.shape_) << " to convert "
|
||||
<< node_def.op();
|
||||
nvinfer1::IFullyConnectedLayer* layer =
|
||||
params->converter->network()->addFullyConnected(
|
||||
*tensor_a, noutput, weights.GetTrtWeights(), biases.GetTrtWeights());
|
||||
@ -5454,14 +5604,19 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params,
params->converter->SetLayerName(layer, node_def);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);

// Reshape output to 1D - this will be a no-op unless using int8 precision.
// A fully connected layer produces output with two trailing singleton
// dimensions. We remove these.
auto output_dim = output_tensor->getDimensions();
output_dim.nbDims = 1;
output_dim.nbDims -= 2;
// A zero in output_dim indicates copying the corresponding input dimension
// value during reshape.
std::fill(output_dim.d, output_dim.d + output_dim.nbDims, 0);
TF_RETURN_IF_ERROR(PrepareTensorForShape(
params->converter, TRT_TensorOrWeights(output_tensor), output_dim,
/*validation_only=*/false, &output_tensor, node_def, /*op_instance=*/1));

/*validation_only=*/false, &output_tensor, node_def,
/*op_instance=*/1));
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
*converted = true;
return Status::OK();
}
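The reshape bookkeeping in this helper relies on TensorRT's convention that a 0 entry in reshape dimensions means "copy the corresponding input dimension". A minimal standalone sketch of how the FC input shape is assembled (the helper name is invented for illustration):

#include <NvInfer.h>

// Keeps every dimension of input_dim via 0 placeholders and appends two
// trailing singleton dims, e.g. a [N, C] input yields dims {0, 0, 1, 1},
// which the shuffle layer resolves to [N, C, 1, 1].
nvinfer1::Dims FcInputReshapeDims(const nvinfer1::Dims& input_dim) {
  nvinfer1::Dims reshape_dim{input_dim.nbDims + 2, {}};  // d[] zero-initialized
  reshape_dim.d[input_dim.nbDims] = 1;
  reshape_dim.d[input_dim.nbDims + 1] = 1;
  return reshape_dim;
}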

@ -5469,82 +5624,66 @@ Status ConvertMatMulHelper(OpConverterParams* params,
TRT_TensorOrWeights input_a,
TRT_TensorOrWeights input_b, bool transpose_a,
bool transpose_b) {
// TODO: ReorderCKtoKC is currently not general enough to transpose weights
// that are not 2D.
if ((transpose_a && input_a.is_weights() &&
input_a.GetTrtDims().nbDims != 2) ||
(transpose_b && input_b.is_weights() &&
input_b.GetTrtDims().nbDims != 2)) {
return errors::InvalidArgument(
"Cannot currently transpose constant input if it is not 2 dimensional");
if (params->use_implicit_batch) {
// In implicit batch mode we are very limited in when we can multiply 2D
// matrices. If input_A is a 2D tensor, then nbDims==1 (implicit batch dim
// not counted). If A is not transposed and B is weight, then we can convert
// this treating A as a batch of vectors. This is the only possibility
// to implement MatMul with 2D input in implicit batch mode.
if ((input_a.GetTrtDims().nbDims < 2 &&
(transpose_a || !input_b.is_weights())) ||
(input_b.GetTrtDims().nbDims < 2)) {
return errors::InvalidArgument(
"MatMul with 2D tensors requires explicit batch mode, or that tensor"
" A is not transposed and B is a constant tensor.");
}
}

// If A is a tensor, we can only transpose if it is at least 3D in TF,
// or TRT will not do the correct transposition.
if (transpose_a && input_a.is_tensor() && input_a.GetTrtDims().nbDims < 2) {
return errors::InvalidArgument(
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions.");
}

// If B is a tensor, then it must be at least 3D in TF,
// or TRT won't be able to handle the multiply correctly.
if (input_b.is_tensor() && input_b.GetTrtDims().nbDims < 2) {
return errors::InvalidArgument(
"Second input must either be a constant, or contain at least 2 "
"non-batch dimensions.");
}
if (params->validation_only) return Status::OK();

// If an FC layer can be used and would be faster, use that instead.
const bool can_use_fc =
!transpose_a && input_a.is_tensor() && input_b.is_weights();
const bool should_use_fc = can_use_fc && input_a.GetTrtDims().nbDims >= 3 &&
input_b.GetTrtDims().nbDims == 2;
// If int8 is specified, FC must be used unless it is not compatible, as MM
// does not support int8 at this time.
if (should_use_fc || (can_use_fc && params->converter->precision_mode() ==
TrtPrecisionMode::INT8)) {
return ConvertFullyConnectedHelper(params, input_a.tensor(),
input_b.weights(), transpose_b);
}
bool converted = false;
TF_RETURN_IF_ERROR(ConvertFullyConnectedHelper(
params, input_a, input_b, transpose_a, transpose_b, &converted));
if (converted == true) return Status::OK();

const auto get_matrix_op = [](nvinfer1::ITensor* in,
bool transpose) -> nvinfer1::MatrixOperation {
return (in->getDimensions().nbDims < 2) ? nvinfer1::MatrixOperation::kVECTOR
: (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
: nvinfer1::MatrixOperation::kNONE;
};

// If the MatMul operand is a constant, applies transposes at conversion-time
// as necessary. If the operand is a tensor, does nothing. If required
// transposes were applied, sets transpose to false.
const auto prepare_matmul_operand =
[&params](TRT_TensorOrWeights operand,
bool* transpose) -> nvinfer1::ITensor* {
const auto convert_to_itensor =
[&params](TRT_TensorOrWeights operand) -> nvinfer1::ITensor* {
if (operand.is_tensor()) {
return operand.tensor();
} else {
TRT_ShapedWeights weights(operand.weights().TrtDType());
if (*transpose) {
weights = params->weight_store->GetTempWeights(operand.weights());
ReorderCKtoKC(operand.weights(), &weights);
// Weights have been transposed, can set transpose to false
*transpose = false;
} else {
weights = operand.weights();
}
return params->converter->CreateConstantLayer(weights, weights.shape_);
return params->converter->CreateConstantLayer(operand.weights(),
operand.GetTrtDims());
}
};

nvinfer1::ITensor* tensor_a = prepare_matmul_operand(input_a, &transpose_a);
nvinfer1::ITensor* tensor_b = prepare_matmul_operand(input_b, &transpose_b);
nvinfer1::ITensor* tensor_a = convert_to_itensor(input_a);
nvinfer1::ITensor* tensor_b = convert_to_itensor(input_b);

const auto get_matrix_op = [](nvinfer1::ITensor* in,
bool transpose) -> nvinfer1::MatrixOperation {
return (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
: nvinfer1::MatrixOperation::kNONE;
};
nvinfer1::MatrixOperation op_a, op_b;
// Note: In implicit batch mode kTRANSPOSE and kNONE are only valid if the
// matrix has at least 2 non-batch dimensions. In implicit batch mode, if A has
// 1 dim (excluding batch dim), then we can only use kVECTOR, which will treat
// matrix A as a batch of vectors.
op_a = (tensor_a->getDimensions().nbDims < 2)
? nvinfer1::MatrixOperation::kVECTOR
: get_matrix_op(tensor_a, transpose_a);
// In implicit batch mode, if B has only 1 dim (excluding batch dim) then we
// already reject the case and don't convert. One could consider using the
// kVECTOR flag to express C = MatMul(A, B.T) if A is weight, but the result
// will not have the correct shape: in TRT's implicit batch implementation,
// the result is a batch of vectors D_ji = A_ik * B_jk, where j is the batch
// dimension. In contrast, the TF MatMul op produces C = D.T, and we cannot
// transpose over the batch dimension (implicit batch mode).
op_b = get_matrix_op(tensor_b, transpose_b);

nvinfer1::IMatrixMultiplyLayer* layer =
params->converter->network()->addMatrixMultiply(
*tensor_a, get_matrix_op(tensor_a, transpose_a), *tensor_b,
get_matrix_op(tensor_b, transpose_b));
params->converter->network()->addMatrixMultiply(*tensor_a, op_a,
*tensor_b, op_b);

const auto& node_def = params->node_def;
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
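The FC path earlier in this hunk pre-transposes constant weights with ReorderCKtoKC because, as its comment notes, the FC layer applies its own transpose. A standalone sketch of that reorder for a plain float buffer (the real helper operates on TRT_ShapedWeights and supports several dtypes):

#include <vector>

// Transposes a row-major C x K matrix into a row-major K x C matrix.
std::vector<float> ReorderCKtoKCSketch(const std::vector<float>& ck, int c, int k) {
  std::vector<float> kc(ck.size());
  for (int i = 0; i < c; ++i) {
    for (int j = 0; j < k; ++j) {
      kc[j * c + i] = ck[i * k + j];  // element (i, j) moves to (j, i)
    }
  }
  return kc;
}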
@ -5582,44 +5721,45 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
" inputs but expected 2, at ",
node_def.name());
}
// TODO(tmorris): Enable once false is updated to mean either tensor or weight
// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
// false}}));
TF_RETURN_IF_ERROR(CheckInputsWeights(
*params, {{"x", TrtInputArg::kBoth}, {"y", TrtInputArg::kBoth}}));
// TODO(tfeher): Consider adding INT8 type because FC layer can support it.
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
return errors::InvalidArgument(
"All inputs are weights, but Grappler is expected to fold them.");
}
if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor() &&
inputs.at(0).GetTrtDims().nbDims != inputs.at(1).GetTrtDims().nbDims) {
return errors::Unimplemented(
"Inputs must have the same rank if they are both tensors.");
}

TFAttrs attrs(node_def);
const bool transpose_a = attrs.get<bool>("adj_x");
const bool transpose_b = attrs.get<bool>("adj_y");

// There is no way to batch constants in TRT. Example:
// Tensor with TF Dims: 12 5 3 -> TRT Dims: 5 3
// Weight with TF Dims: 12 3 6 -> TRT Dims: 12 3 6
// It is not possible to treat the weight input as a batched [3, 6] tensor.
// In case input_l is weight, check whether input_l has implicit batch mode
// compatible batch dim.
const auto check_weight_is_not_batched =
[](const TRT_TensorOrWeights& input_l,
const TRT_TensorOrWeights& input_r) {
// If input_l is a weight, then input_r must be a tensor because
// otherwise the op would be handled by Grappler.
// There is no way to batch constants in TRT using implicit batch mode.
// Example:
// Tensor with TF Dims: 12 5 3 -> TRT Dims: 5 3
// Weight with TF Dims: 12 3 6 -> TRT Dims: 12 3 6
// It is not possible to treat the weight input as a batched [3, 6]
// tensor. Batched weight tensors must have batch dim = 1 (after the
// broadcast).
if (input_l.is_weights() &&
input_l.GetTrtDims().nbDims > input_r.GetTrtDims().nbDims &&
input_l.GetTrtDims().d[0] != 1) {
return errors::Unimplemented(
"TensorRT does not support batched constants.");
"TensorRT does not support batched constants in implicit batch "
"mode.");
}
return Status::OK();
};
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(0), inputs.at(1)));
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(1), inputs.at(0)));
if (params->use_implicit_batch) {
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(0), inputs.at(1)));
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(1), inputs.at(0)));
}

// Broadcast inputs. We don't check feasibility since the dimensions in a
// MatMul don't need to match. For example, consider a valid set of inputs
@ -5627,22 +5767,14 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
// input 0: [N, T, C]
// input 1: [1, C, K]
// Since C != K and T != C, check feasibility would fail.
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
inputs.at(0), inputs.at(1), /*check_feasibility=*/false,
params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
TF_RETURN_IF_ERROR(
PrepareTensorForShape(params->converter, inputs.at(0), broadcasted_dims_l,
params->validation_only, &tensor_l, node_def));
TF_RETURN_IF_ERROR(
PrepareTensorForShape(params->converter, inputs.at(1), broadcasted_dims_r,
params->validation_only, &tensor_r, node_def));
auto input_l = std::make_unique<TRT_TensorOrWeights>(inputs.at(0));
auto input_r = std::make_unique<TRT_TensorOrWeights>(inputs.at(1));
TF_RETURN_IF_ERROR(BroadcastTensors(input_l, input_r,
/*check_feasibility=*/false, params));

if (params->validation_only) return Status::OK();

return ConvertMatMulHelper(params, TRT_TensorOrWeights(tensor_l),
TRT_TensorOrWeights(tensor_r), transpose_a,
return ConvertMatMulHelper(params, *input_l, *input_r, transpose_a,
transpose_b);
}
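To make the skipped feasibility check concrete (the shape values are illustrative): with inputs [N, T, C] = [8, 10, 16] and [1, C, K] = [1, 16, 32], an elementwise-style per-dimension check would reject 10 vs 16 and 16 vs 32, yet the pair is a perfectly valid batched MatMul producing [8, 10, 32]; that is why check_feasibility is passed as false and BroadcastTensors only aligns the ranks.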

@ -192,6 +192,8 @@ class TRT_ShapedWeights {
template <typename T>
Status SetValues(T value);

Status SetShape(nvinfer1::Dims dims);

int64_t count() const;

size_t size_bytes() const;
@ -558,10 +560,12 @@ class Converter {
// This can be achieved by calling DynamicReshape(input, {{2,4},{0,2}},
// params).
//
// Before each slice we can insert a new dim if the corresponding
// Before each slice we can insert new dims if the corresponding
// size_for_added_dims element is not negative. The size_for_added_dims array
// can have more than slices.size() elements, in order to insert a dimension
// ater the last slice.
// after the last slice. For example, to add two leading 1 dimensions, and
// three trailing 1 dimensions, call DynamicReshape(input, {{0,nbDims}},
// {2, 3}).
//
// Parameters:
// input - input tensor
@ -712,8 +712,8 @@ TEST(TrtNodeValidator, IsTensorRTCandidate) {
ExpectStatus(
validator.IsTensorRTCandidate(incompatible_matmul.operation.node()),
error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions.");
"MatMul with 2D tensors requires explicit batch mode, or that tensor A "
"is not transposed and B is a constant tensor.");
ExpectStatus(validator.IsTensorRTCandidate(unsupported_op.operation.node()),
error::UNIMPLEMENTED, "Op type Erf is not supported");
ExpectStatus(validator.IsTensorRTCandidate(
@ -2466,87 +2466,138 @@ TEST_P(OpConverter_FP32_Test, ConvertShape) {
}
}

// Helper function for testing MatMul and BatchMatMul
// get_matmul corresponds to the function used to generate the node. It should
// accept (DataType, transpose_a, transpose_b) as parameters.
struct MatMulTestParams {
std::vector<int> shape_a;
std::vector<int> values_a;
bool transpose_a;
std::vector<int> shape_b;
std::vector<int> values_b;
bool transpose_b;
std::vector<int> expected_shape;
std::vector<int> expected_output;
};

// Helper function for testing MatMul and BatchMatMul. get_matmul is a function
// used to generate the node. It accepts (DataType, transpose_a, transpose_b) as
// parameters.
void TestMatMulHelper(
OpConverterTest* test,
ParameterizedOpConverterTestBase* test,
const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
const std::string& op_name) {
// HACK: This needs to be done in a better way.
const bool is_batch_matmul = op_name == "BatchMatMul";
const std::vector<MatMulTestParams>& params) {
{
// Unsupported data type.
test->Reset();
NodeDef node_def = get_matmul(DT_INT32, false, false);
test->AddTestTensor("input", {2}, /*batch_size=*/1,
nvinfer1::DataType::kINT32);
test->AddTestTensor("input", {1, 2}, DT_INT32, {});
test->AddTestWeights<int32>("weights", {2, 1}, {3, 5});
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
StrCat("Data type int32 is not supported for ", op_name,
StrCat("Data type int32 is not supported for ", node_def.op(),
", must be one of [float, half], at my_matmul")
.c_str());
}
// OK.
for (bool transpose_a : {false, true}) {
for (bool transpose_b : {false, true}) {
test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b);
test->AddTestTensor("input", {2}, /*batch_size=*/1);
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) {
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"TensorRT does not support batched constants.");
continue;
} else if (transpose_a) {
test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions");
continue;
}
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());

const DataVec input_data{{"input", test->AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", test->ConstructTensor<float>(2)}};
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(2, 3));
}
}
// FC conversion depends on whether the last dim of A is known or not. In
// Dynamic shape mode, we will check whether A is handled correctly if it has
// a partially known input shape (last dim known).
std::vector<bool> a_test_partial_shape_values{false};
if (test->get_trt_mode() == TrtTestMode::kDynamicShape) {
a_test_partial_shape_values.push_back(true);
}
// OK, 3D inputs
for (bool transpose_b : {false, true}) {
test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b);
test->AddTestTensor("input", {2}, /*batch_size=*/1);
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) {
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"TensorRT does not support batched constants.");
continue;
}
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test->AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", test->ConstructTensor<float>(2)}};
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(2, 3));

for (auto p : params) {
for (bool a_is_tensor : {true, false}) {
for (bool b_is_tensor : {true, false}) {
for (bool a_partial_shape : a_test_partial_shape_values) {
if (a_partial_shape && !a_is_tensor) {
// Only tensors can have partial shape.
continue;
}
if (!a_is_tensor && !b_is_tensor) {
// Skip test when both args are weights. We do not convert this
// since const folding eliminates this case.
continue;
}
SCOPED_TRACE(StrCat("A", p.transpose_a ? ".T" : "", " is ",
a_is_tensor ? "tensor" : "weight", ", B",
p.transpose_b ? ".T" : "", " is ",
b_is_tensor ? "tensor " : "weight, rank A ",
p.shape_a.size(), ", rank B ", p.shape_b.size()));
test->Reset();

NodeDef node_def =
get_matmul(test->get_tf_type(), p.transpose_a, p.transpose_b);
const bool is_batch_matmul = node_def.op() == "BatchMatMul";

if (a_is_tensor) {
if (a_partial_shape) {
// Prepare a partial shape for A where only the last dim is known.
std::vector<int> partial_shape(p.shape_a.size(), -1);
int k = p.shape_a.size() - 1;
partial_shape.at(k) = p.shape_a.at(k);
test->AddTestTensor("input", p.shape_a, test->get_tf_type(),
p.values_a, partial_shape);
} else {
test->AddTestTensor("input", p.shape_a, p.values_a);
}
} else {
test->AddTestWeights("input", p.shape_a, p.values_a,
test->get_tf_type());
}
if (b_is_tensor) {
if (a_is_tensor && p.shape_a[0] != p.shape_b[0] &&
test->get_trt_mode() == TrtTestMode::kImplicitBatch) {
VLOG(2) << "Skipping test with incompatible batch dimensions";
continue;
}
test->AddTestTensor("weights", p.shape_b, p.values_b);
} else {
test->AddTestWeights("weights", p.shape_b, p.values_b,
test->get_tf_type());
}

Status conversion_status = Status::OK();
if (test->get_trt_mode() == TrtTestMode::kImplicitBatch) {
// Implicit batch mode has several restrictions. We change conversion
// status accordingly.
if (is_batch_matmul) {
if (a_is_tensor && p.shape_a.size() < p.shape_b.size()) {
conversion_status = errors::InvalidArgument(
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims ",
p.shape_a.size(), " vs broadcast #dims ", p.shape_b.size(),
")");
}
if (b_is_tensor && p.shape_b.size() < p.shape_a.size()) {
conversion_status = errors::InvalidArgument(
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims ",
p.shape_b.size(), " vs broadcast #dims ", p.shape_a.size(),
")");
}
if ((!a_is_tensor || !b_is_tensor) && p.shape_a[0] != 1) {
conversion_status = errors::Unimplemented(
"TensorRT does not support batched constants in implicit "
"batch mode.");
}
} else if ((a_is_tensor && p.shape_a.size() <= 2 &&
(p.transpose_a || b_is_tensor)) ||
(b_is_tensor && p.shape_b.size() <= 2)) {
conversion_status = errors::InvalidArgument(
"MatMul with 2D tensors requires explicit batch mode, or that"
" tensor A is not transposed and B is a constant tensor.");
}
}

test->TestOpConverter("my_matmul", node_def, p.expected_shape,
conversion_status, Status::OK(),
ElementsAreArray(p.expected_output));
if (!conversion_status.ok()) {
VLOG(2) << "Converted with status " << conversion_status;
}
VLOG(2) << "== Finished test iteration ==";
}
}
}
}
}
@ -2563,7 +2614,39 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_found) {
EXPECT_EQ(expect_found, layer_found);
}

TEST_F(OpConverterTest, ConvertMatMul) {
std::vector<MatMulTestParams> GetMatMulTestParams() {
std::vector<MatMulTestParams> params{
// clang-format off
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, false,  // A (shape, val, T?)
{2, 2}, {0, 1, 2, 3}, false,  // B (shape, val, T?)
{2, 2}, {2, 3, 6, 11}},  // result (shape, val)
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, false,
{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {1, 3, 3, 13}},
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {0, 1, 2, 3}, false,
{2, 2}, {4, 6, 6, 10}},
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {2, 6, 3, 11}},
MatMulTestParams{{2, 3}, {0, 1, 2, 3, 4, 5}, false,
{2, 3}, {1, 2, 3, 4, 5, 6}, true,
{2, 2}, {8, 17, 26, 62}},
MatMulTestParams{{2, 3}, {0, 1, 2, 3, 4, 5}, true,
{2, 3}, {1, 2, 3, 4, 5, 6}, false,
{3, 3}, {12, 15, 18, 17, 22, 27, 22, 29, 36}},
MatMulTestParams{{3, 2}, {0, 1, 2, 3, 4, 5}, false,
{2, 3}, {1, 2, 3, 4, 5, 6}, false,
{3, 3}, {4, 5, 6, 14, 19, 24, 24, 33, 42}},
MatMulTestParams{{3, 2}, {0, 1, 2, 3, 4, 5}, true,
{2, 3}, {1, 2, 3, 4, 5, 6}, true,
{2, 2}, {16, 34, 22, 49}},
// clang-format on
};
return params;
}
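As a sanity check on the first two entries: with A = [[0, 1], [2, 3]] and B = [[0, 1], [2, 3]], A*B = [[0*0+1*2, 0*1+1*3], [2*0+3*2, 2*1+3*3]] = [[2, 3], [6, 11]], and the transpose_b variant A*B^T = [[1, 3], [3, 13]], matching the expected outputs {2, 3, 6, 11} and {1, 3, 3, 13}.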

TEST_P(OpConverter_FP32_Test, ConvertMatMul) {
// Get the NodeDef for MatMul.
auto get_matmul_nodedef = [](DataType dtype, bool transpose_a,
bool transpose_b) -> NodeDef {
@ -2577,64 +2660,10 @@ TEST_F(OpConverterTest, ConvertMatMul) {
return matmul.operation.node()->def();
};

// Additional test cases specific to MatMul
{
// Can only transpose A if it is 2D in TRT
Reset();
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, true, false);
AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions.");
}
{
// B must always have 2 non-batch dimensions
Reset();
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestTensor("weights", {2}, /*batch_size=*/1);
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Second input must either be a constant, or contain at least 2 "
"non-batch dimensions.");
}
{
// We can never transpose weights that are not 2D.
Reset();
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, true, false);
AddTestWeights<float>("input", {1, 1, 2}, {0, 1});
AddTestTensor("weights", {2, 2}, /*batch_size=*/1);
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Cannot currently transpose constant input if it is not 2 dimensional");
}
{
// Make sure that INT8 mode uses IFullyConnectedLayer when possible.
Reset(TrtPrecisionMode::INT8);
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {2, 1, 1});
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
RunValidationAndConversion(node_def);
CheckAddedLayers<nvinfer1::IMatrixMultiplyLayer>(this, false);
CheckAddedLayers<nvinfer1::IFullyConnectedLayer>(this, true);
}
{
// Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not
// compatible. In this case we can't use FC because weights is a tensor.
Reset(TrtPrecisionMode::INT8);
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {2, 1, 1});
AddTestTensor("weights", {2, 2});
RunValidationAndConversion(node_def);
CheckAddedLayers<nvinfer1::IMatrixMultiplyLayer>(this, true);
CheckAddedLayers<nvinfer1::IFullyConnectedLayer>(this, false);
}
TestMatMulHelper(this, get_matmul_nodedef, "MatMul");
TestMatMulHelper(this, get_matmul_nodedef, GetMatMulTestParams());
}

TEST_F(OpConverterTest, ConvertBatchMatMul) {
TEST_P(OpConverter_FP32_Test, ConvertBatchMatMul) {
// Get the NodeDef for BatchMatMul.
auto get_batch_matmul_nodedef = [](DataType dtype, bool transpose_a,
bool transpose_b) -> NodeDef {
@ -2648,61 +2677,93 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) {
return matmul.operation.node()->def();
};

{
// Can't broadcast two tensor inputs of different rank.
Reset();
NodeDef node_def = get_batch_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {1, 2, 2}, /*batch_size=*/2);
AddTestTensor("weights", {2}, /*batch_size=*/2);
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Inputs must have the same rank if they are both tensors.");
}
{
// Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not
// compatible. In this case we can't use FC because transpose_a is true.
Reset(TrtPrecisionMode::INT8);
NodeDef node_def = get_batch_matmul_nodedef(DT_FLOAT, true, false);
AddTestTensor("input", {1, 2, 2});
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
RunValidationAndConversion(node_def);
CheckAddedLayers<nvinfer1::IMatrixMultiplyLayer>(this, true);
CheckAddedLayers<nvinfer1::IFullyConnectedLayer>(this, false);
}
// We derive test data from the MatMul test params by adding extra leading
// dimensions.
std::vector<MatMulTestParams> params_2d = GetMatMulTestParams();
std::vector<MatMulTestParams> params;
params.reserve(params_2d.size() * 3 + 1);

for (bool transpose_a : {false, true}) {
for (bool transpose_b : {false, true}) {
Reset();
NodeDef node_def =
get_batch_matmul_nodedef(DT_FLOAT, transpose_a, transpose_b);
AddTestTensor("input", {2, 2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {1, 2, 2}, {1, 2, 3, 4});
auto insert_ones = [](std::vector<int> v, int n) {
std::vector<int> ones(n, 1);
ones.insert(ones.end(), v.begin(), v.end());
return ones;
};

RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", AsTensor<float>({0, 1, 2, 3})}};
DataVec output_data{{"my_matmul", ConstructTensor<float>(4)}};
TF_EXPECT_OK(BuildAndRun(input_data, &output_data));
if (!transpose_a && !transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(3, 4, 11, 16));
} else if (transpose_a && transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(4, 8, 7, 15));
} else if (transpose_a) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(6, 8, 10, 14));
} else if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(2, 4, 8, 18));
}
}
}
// Add a leading 1 dimension to A, B and result.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[](MatMulTestParams p) {
p.shape_a.insert(p.shape_a.begin(), 1);
p.shape_b.insert(p.shape_b.begin(), 1);
p.expected_shape.insert(p.expected_shape.begin(), 1);
return p;
});

TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul");
// Test with N > 1: weights cannot be batched in implicit batch mode.
// clang-format off
params.push_back(
MatMulTestParams{{2, 2, 2}, {0, 1, 2, 3, 0, 1, 2, 3}, false,  // A
{2, 2, 2}, {0, 1, 2, 3, 0, 1, 2, 3}, false,  // B
{2, 2, 2}, {2, 3, 6, 11, 2, 3, 6, 11}}  // result
);

params.push_back(
MatMulTestParams{{2, 2, 3}, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5},
false,
{2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, true,
{2, 2, 2}, {8, 17, 26, 62, 8, 17, 26, 62}});
// clang-format on

// Add two leading 1 dimensions to A, B and result.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a = insert_ones(p.shape_a, 2);
p.shape_b = insert_ones(p.shape_b, 2);
p.expected_shape = insert_ones(p.expected_shape, 2);
return p;
});

// Test broadcast: add two leading 1 dimensions to A, but not to B.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a = insert_ones(p.shape_a, 2);
p.expected_shape = insert_ones(p.expected_shape, 2);
return p;
});

// Test broadcast: add a leading 1 dimension to A and two leading 1s to B.
// Broadcasting A needs a dynamic broadcast which will be incompatible with
// FC layer.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a = insert_ones(p.shape_a, 1);
p.shape_b = insert_ones(p.shape_b, 2);
p.expected_shape = insert_ones(p.expected_shape, 2);
return p;
});

// Test with N > 1: since weights cannot be batched in implicit batch mode,
// we test with batch size 2.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a.insert(p.shape_a.begin(), 2);
p.values_a.reserve(p.values_a.size() * 2);
p.values_a.insert(p.values_a.end(), p.values_a.begin(),
p.values_a.end());

p.shape_b.insert(p.shape_b.begin(), 2);
p.values_b.reserve(p.values_b.size() * 2);
p.values_b.insert(p.values_b.end(), p.values_b.begin(),
p.values_b.end());

p.expected_shape.insert(p.expected_shape.begin(), 2);
p.expected_output.reserve(p.expected_output.size() * 2);
p.expected_output.insert(p.expected_output.end(),
p.expected_output.begin(),
p.expected_output.end());
return p;
});

TestMatMulHelper(this, get_batch_matmul_nodedef, params);
}

TEST_P(OpConverter_FP32_FP16_Test, ConvertBiasAdd) {
@ -247,7 +247,9 @@ Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine,
bool status = output_tensor->CopyFrom(*output_tensor, output_shape);
if (!status) {
return errors::Internal(
"Buffer size do not match while reshaping output tensors");
"Buffer size (", output_tensor->NumElements(),
") does not match while reshaping output tensors to shape ",
output_shape.DebugString());
}
}

@ -110,13 +110,14 @@ namespace {

// Log just once by default (on default log level), and let the user adjust
// the log level for more detailed logging.
void LogAtLeastOnce(const std::string& log_message) {
if (VLOG_IS_ON(1)) {
VLOG(1) << log_message;
} else {
LOG_FIRST_N(INFO, 1) << log_message;
#define LOG_AT_LEAST_ONCE(log_message) \
{ \
if (VLOG_IS_ON(1)) { \
VLOG(1) << log_message; \
} else { \
LOG_FIRST_N(INFO, 1) << log_message; \
} \
}
}

} // namespace
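A plausible motivation for turning the LogAtLeastOnce function into a macro (an assumption; the diff does not state it): LOG_FIRST_N keeps its occurrence counter in a static local at the expansion site, so a shared helper gives every message one common counter and only the very first message would ever print, while a macro gives each call site its own counter. A self-contained sketch of the difference:

#include <iostream>

// Simplified stand-in for LOG_FIRST_N(INFO, 1): prints only the first time
// this particular expansion site is reached, because `count` is a local static.
#define SKETCH_LOG_FIRST_ONCE(msg)                \
  {                                               \
    static int count = 0;                         \
    if (count++ == 0) std::cout << (msg) << "\n"; \
  }

void AsFunction(const char* msg) {
  SKETCH_LOG_FIRST_ONCE(msg);  // a single static shared by every caller
}

int main() {
  AsFunction("A");             // printed
  AsFunction("B");             // suppressed: same expansion site as "A"
  SKETCH_LOG_FIRST_ONCE("C");  // printed: its own expansion site
  SKETCH_LOG_FIRST_ONCE("D");  // printed: another expansion site
}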

@ -132,18 +133,19 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
// based on the devices in the module.
if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
MlirOptimizationPassState::Disabled) {
LogAtLeastOnce("Skipping MLIR TPU Bridge, session flag not enabled");
LOG_AT_LEAST_ONCE("Skipping MLIR TPU Bridge, session flag not enabled");
mlir_bridge_gauge_v2->GetCell()->Set(false);
return Status::OK();
}

// Skip MLIR TPU Bridge if no TPU devices or TPU ops found.
if (!HasTPUDevicesAndOps(module)) {
LogAtLeastOnce("Skipping MLIR TPU Bridge, no TPU devices or TPU ops found");
LOG_AT_LEAST_ONCE(
"Skipping MLIR TPU Bridge, no TPU devices or TPU ops found");
return Status::OK();
}

LogAtLeastOnce("Running MLIR TPU Bridge");
LOG_AT_LEAST_ONCE("Running MLIR TPU Bridge");

mlir_bridge_gauge_v2->GetCell()->Set(true);
TF_RETURN_IF_ERROR(
@ -176,7 +178,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
// based on the devices in the module.
if (!IsEnabled(/*device_set=*/nullptr, options.session_options->config,
**options.graph)) {
LogAtLeastOnce(
LOG_AT_LEAST_ONCE(
"Skipping MLIR TPU Bridge V1 Compat, session flag not enabled");
mlir_bridge_gauge_v1->GetCell()->Set(false);
return Status::OK();
@ -184,12 +186,12 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,

// Skip MLIR TPU Bridge if no TPU devices or TPU ops found.
if (!HasTPUDevicesAndOps(module)) {
LogAtLeastOnce(
LOG_AT_LEAST_ONCE(
"Skipping MLIR TPU Bridge V1 Compat, no TPU devices or TPU ops found");
return Status::OK();
}

LogAtLeastOnce("Running MLIR TPU Bridge V1 Compat");
LOG_AT_LEAST_ONCE("Running MLIR TPU Bridge V1 Compat");

mlir_bridge_gauge_v1->GetCell()->Set(true);
TF_RETURN_IF_ERROR(
@ -277,9 +277,23 @@ def conv(lhs,
precision_config_proto = ""
if precision_config:
precision_config_proto = precision_config.SerializeToString()
needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
if preferred_element_type is None:
preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
return gen_xla_ops.xla_conv_v2(
if needs_v2:
return gen_xla_ops.xla_conv_v2(
lhs,
rhs,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
feature_group_count=feature_group_count,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
return gen_xla_ops.xla_conv(
lhs,
rhs,
window_strides=window_strides,
@ -289,7 +303,6 @@ def conv(lhs,
feature_group_count=feature_group_count,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)


@ -309,14 +322,22 @@ def dot_general(lhs,
precision_config_proto = ""
if precision_config:
precision_config_proto = precision_config.SerializeToString()
needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
if preferred_element_type is None:
preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
return gen_xla_ops.xla_dot_v2(
if needs_v2:
return gen_xla_ops.xla_dot_v2(
lhs,
rhs,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
return gen_xla_ops.xla_dot(
lhs,
rhs,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
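The dispatch logic is the same in both functions: the v2 kernels are only needed when a preferred_element_type is requested or the operand dtypes differ. For example (dtypes illustrative), a float32 x float32 dot with no preferred_element_type still calls gen_xla_ops.xla_dot, preserving the old behavior, while a bfloat16 x float32 dot sets needs_v2 and calls xla_dot_v2 with preferred_element_type = np_utils.result_type(bfloat16, float32), presumably float32.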


@ -804,9 +804,12 @@ Status XlaCompiler::CompileFunction(
}

VLOG(1) << "====================================================";
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
*graph, config_proto,
/*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
MlirBridgeRolloutPolicy policy = MlirBridgeRolloutPolicy::kDisabledByUser;
if (options.is_entry_computation) {
policy = GetMlirBridgeRolloutPolicy(
*graph, config_proto,
/*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
}
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;

@ -89,10 +89,16 @@ class XlaOpKernelContext {
// xla::PRIMITIVE_TYPE_INVALID.
xla::PrimitiveType InputXlaType(absl::string_view name);

// Returns the shape of input `index`.
// Returns the shape of input at `index` or input with the given `name`. Note
// that in case the shape of the input is not static, then the returned shape
// has bounds as the dimension size instead of having unknown dimensions. Use
// InputXlaShape instead that provides shapes with dynamism information.
//
ABSL_DEPRECATED(
"Prefer InputXlaShape which handles dynamic shapes accurately.")
TensorShape InputShape(int index);

// Returns the shape of input with name `name`.
ABSL_DEPRECATED(
"Prefer InputXlaShape which handles dynamic shapes accurately.")
TensorShape InputShape(absl::string_view name);
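Concretely (sizes illustrative): for an input whose leading dimension is dynamic with an upper bound of 10, InputShape reports 10 as if the dimension were static, whereas InputXlaShape returns an xla::Shape that marks the dimension dynamic; that loss of dynamism information is what the deprecation steers callers away from.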

// Returns input `index` as a XlaOp. Unlike

@ -130,7 +130,7 @@ LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
}
for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
*argument_shapes[i], /*minor_to_major_only=*/true)) {
*argument_shapes[i])) {
return InvalidParameterArgument(
executable_.get(), i,
"Argument does not match host shape or layout of computation "
@ -175,7 +175,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ShapedBuffer* const arg : arguments) {
argument_shapes.push_back(&arg->on_device_shape());
argument_shapes.push_back(&arg->on_host_shape());
}
return AsyncCallAndBlockHostUntilDone<xla::ScopedShapedBuffer>(
argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
@ -188,7 +188,7 @@ StatusOr<ExecutionOutput> LocalExecutable::Run(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ExecutionInput& arg : arguments) {
argument_shapes.push_back(&arg.shape());
argument_shapes.push_back(&arg.host_shape());
}
return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
@ -243,7 +243,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ShapedBuffer* const arg : arguments) {
argument_shapes.push_back(&arg->on_device_shape());
argument_shapes.push_back(&arg->on_host_shape());
}
TF_ASSIGN_OR_RETURN(auto options_and_stream,
RunHelper(argument_shapes, run_options));
@ -324,7 +324,7 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ExecutionInput& arg : arguments) {
argument_shapes.push_back(&arg.shape());
argument_shapes.push_back(&arg.host_shape());
}
return RunAsync(argument_shapes, std::move(arguments), run_options);
}

@ -173,6 +173,12 @@ static void AllocateFlags() {
return true;
};

// Custom "sub-parser" lambda for xla_gpu_llvm_ir_file.
auto setter_for_xla_gpu_llvm_ir_file = [](const string& value) {
flag_values->add_xla_gpu_llvm_ir_file(value);
return true;
};

// Custom "sub-parser" lambda for xla_backend_extra_options.
auto setter_for_xla_backend_extra_options =
[](string comma_separated_values) {
@ -370,7 +376,15 @@ static void AllocateFlags() {
"If non-empty, specifies a file containing ptx to use. The filename "
"prefix must have the same pattern as PTX dumped by XLA. This allows to "
"match one specific module. General workflow. Get the generated module "
"ptx from XLA. Modify it. Then pass it back via this option."));
"ptx from XLA, modify it, then pass it back via this option."));
flag_objects->push_back(tensorflow::Flag(
"xla_gpu_llvm_ir_file", setter_for_xla_gpu_llvm_ir_file, "",
"If non-empty, specifies a file containing textual LLVM IR to use. The "
"filename prefix must have the same pattern as LLVM dumped by XLA "
"(i.e. module_0001.ir-no-opt.ll -> module_0001.MY_NEW_FILE.ll). This "
"allows to match one specific module. General workflow. Get the not "
"optimized LLVM IR from XLA, modify it, then pass it back via this "
"option."));
flag_objects->push_back(tensorflow::Flag(
"xla_test_all_output_layouts",
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),

@ -1368,8 +1368,7 @@ array with the same shape. It is allowed for `operand` to be a scalar (rank 0).

The XLA FFT operation implements the forward and inverse Fourier Transforms for
real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are
supported, except on TPU, where only a single axis is supported (please file a
GitHub issue if you require higher order).
supported.

See also
[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

@ -145,14 +145,14 @@ class LocalDeviceState {
// thread or on the worker thread (depending on thread schedules), not a
// device callback, so it is safe if the destructor frees device resources
// (e.g., GPU objects).
// TODO(phawkins): use move-capture when we can use C++14 features.
template <typename T>
void ThenRelease(se::Stream* stream, T object) const {
void ThenRelease(se::Stream* stream, T&& object) const {
if (callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
}
ThenExecuteOnCallbackThread(callback_stream_.get(),
[object]() { /* releases object */ });
ThenExecuteOnCallbackThread(
callback_stream_.get(),
[object = std::forward<T>(object)]() { /* releases object */ });
}
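A minimal sketch of the C++14 init-capture pattern the new ThenRelease uses, assuming nothing beyond the standard library:

#include <memory>
#include <utility>

template <typename T>
auto MakeReleaser(T&& object) {
  // Init-capture moves (or copies) `object` into the closure; it is released
  // when the closure itself is destroyed, mirroring the comment above.
  return [object = std::forward<T>(object)]() { /* releases object */ };
}

int main() {
  auto buffer = std::make_unique<int>(42);
  auto release = MakeReleaser(std::move(buffer));  // closure now owns buffer
  release();  // body intentionally empty; destroying `release` frees the int
}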

Semaphore& compute_semaphore() { return compute_semaphore_; }

@ -1340,7 +1340,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
// have completed.
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state()
->ThenRelease(transfer_stream, src_device_buffer);
->ThenRelease(transfer_stream, std::move(src_device_buffer));
}
return copy_event_or.status();
}

@ -485,12 +485,13 @@ class CopyRemover {
};

CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
const HloOrdering& ordering)
const HloOrdering& ordering, bool check_live_range_ordering)
: dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
// Construct a list for each HLO buffer in the alias analysis. Maintain a
// map from HloValue to the respective list element representing that
// value. The map is used to construct the copy info map below.
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
// Perform the check only if the default dependence-based ordering is used.
for (const HloBuffer& buffer : alias_analysis.buffers()) {
// No copies should have been inserted within fused computations, so no
// need to remove them. HloOrdering isn't compatible with HloValues inside
@ -498,24 +499,26 @@ class CopyRemover {
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
continue;
}
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
dataflow_) ||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
dataflow_))
<< value_a->ToShortString() << " and "
<< value_b->ToShortString() << " are not ordered";
if (check_live_range_ordering) {
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
dataflow_) ||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
dataflow_))
<< value_a->ToString() << " and " << value_b->ToString()
<< " are not ordered";
}
}
}
}
@ -729,27 +732,31 @@ class CopyRemover {
VLOG(2) << copy->name() << " defines the first value in its buffer";
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
next_dest = Next(*next_dest)) {
// Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
if (!LiveRangeBefore(*src, *next_dest)) {
VLOG(2) << "Not removing the copy: live range of "
<< src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
// Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
for (ValueNode* prev_src = src; prev_src != nullptr;
prev_src = Prev(*prev_src)) {
if (!LiveRangeBefore(*prev_src, *next_dest)) {
VLOG(2) << "Not removing the copy: live range of "
<< prev_src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
}
}
}
for (ValueNode* next_src = Next(*src); next_src != nullptr;
next_src = Next(*next_src)) {
// Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
ValueNode* last_dest = dest->prev;
DCHECK(IsTail(*last_dest));
if (!LiveRangeBefore(*last_dest, *next_src)) {
VLOG(2) << "Not removing the copy: live range of "
<< last_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
for (ValueNode* last_dest = dest->prev; last_dest != nullptr;
last_dest = Prev(*dest)) {
if (!LiveRangeBefore(*last_dest, *next_src)) {
VLOG(2) << "Not removing the copy: live range of "
<< last_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
}
}
}

VLOG(2) << "Splice dest after source.";
// Splice in destination buffer values list right after 'src'.
SpliceAfter(dest, src);
} else if (IsTail(*src)) {
@ -769,32 +776,36 @@ class CopyRemover {
VLOG(2) << copy->name() << " copies the last value ("
<< src->value->ToShortString() << ") in its buffer";

ValueNode* first_src = src->next;
DCHECK(IsHead(*first_src));
for (ValueNode* prev_dest = Prev(*dest);
// nullptr condition handled above in the first 'if' case.
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
if (!LiveRangeBefore(*prev_dest, *first_src)) {
// Live range of value d_{y-1} is not before s_0.
VLOG(2) << "Not removing the copy: live range of "
<< prev_dest->value->ToShortString() << " is not before "
<< first_src->value->ToShortString();
return false;
for (ValueNode* next_src = src->next; next_src != nullptr;
next_src = Next(*next_src)) {
for (ValueNode* prev_dest = Prev(*dest);
// nullptr condition handled above in the first 'if' case.
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
if (!LiveRangeBefore(*prev_dest, *next_src)) {
// Live range of value d_{y-1} is not before s_0.
VLOG(2) << "Not removing the copy: live range of "
<< prev_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
}
}
}
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
next_dest = Next(*next_dest)) {
if (!LiveRangeBefore(*src, *next_dest)) {
// Live range of value s_n is not before d_{y+1}.
VLOG(2) << "Not removing the copy: live range of "
<< src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
for (ValueNode* prev_src = src; prev_src != nullptr;
prev_src = Prev(*prev_src)) {
if (!LiveRangeBefore(*prev_src, *next_dest)) {
// Live range of value s_n is not before d_{y+1}.
VLOG(2) << "Not removing the copy: live range of "
<< prev_src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
}
}
}

VLOG(2) << "Splice src after prev of dest.";
// Splice source buffer values list right after 'prev_dest'.
SpliceAfter(first_src, Prev(*dest));
SpliceAfter(src->next, Prev(*dest));
} else {
VLOG(2) << copy->name()
<< " copies value in middle of source buffer to value in middle "
@ -1175,12 +1186,14 @@ static int64 GetNumExistingCopies(const HloModule* module) {
}

Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module) {
HloModule* module,
bool check_live_range_ordering) {
XLA_VLOG_LINES(4, module->ToString());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, can_share_buffer_));

CopyRemover copy_remover(*module, *alias_analysis, ordering,
check_live_range_ordering);
CopyRemover copy_remover(*module, *alias_analysis, ordering);
if (VLOG_IS_ON(3)) {
LOG(INFO) << "Removing unnecessary copies in " << module->name();
LOG(INFO) << "Buffer values, in dependency order: ";
@ -1200,7 +1213,9 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
VLOG(2) << "Running fixpoint iteration " << num_iterations
<< " of copy elision";
for (HloComputation* computation : module->computations()) {
VLOG(2) << "computation:" << computation->name() << "\n";
for (HloInstruction* instruction : computation->instructions()) {
VLOG(2) << instruction->ToString() << "\n";
if (instruction->opcode() == HloOpcode::kCopy &&
copy_remover.TryElideCopy(instruction)) {
changed = true;
@ -1260,7 +1275,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
name(), "after adding copies to resolve interference", *module);

TF_RETURN_IF_ERROR(
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module,
/*check_live_range_ordering=*/true));
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
*module);
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));

@ -63,8 +63,10 @@ class CopyInsertion : public HloModulePass {
// Try to remove as many copies from the module as possible without
// introducing live range interference. Only copy instructions that are
// eligible for copy elision are considered for removal.
Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module);
// If check_live_range_ordering is true, check that live ranges are ordered
// in all the existing aliased buffers.
Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module,
bool check_live_range_ordering = false);
|
||||
|
||||
// Add copies to address special constraints on the roots of computations not
|
||||
// related to live range interference:
|
||||
|
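A minimal caller sketch for the widened signature (hypothetical helper, not part of this change): omitting the new argument picks up the false default, so pre-existing call sites compile unchanged, while CopyInsertion::Run above opts in with /*check_live_range_ordering=*/true.

// Hypothetical helper, not part of this change: elides copies without the
// aliased-buffer live-range ordering check by relying on the default argument.
Status ElideCopiesLoosely(CopyInsertion* copy_insertion, HloModule* module) {
  return copy_insertion->RemoveUnnecessaryCopies(DependencyHloOrdering(module),
                                                 module);
}
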
@ -2877,6 +2877,56 @@ ENTRY main {
EXPECT_EQ(CountCopies(*module), 1);
}

TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
const string& hlo_string = R"(
HloModule test

fused_computation {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
add0 = f32[10, 20] add(p0, p1)
sub0 = f32[10, 10] subtract(p2, p3)
reshape0 = f32[200] reshape(add0)
reshape1 = f32[100] reshape(sub0)
concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
slice0 = f32[200] slice(concat0), slice={[0:200]}
slice1 = f32[100] slice(concat0), slice={[200:300]}
ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}

ENTRY test {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
gte0 = f32[200] get-tuple-element(fusion), index=0
gte1 = f32[100] get-tuple-element(fusion), index=1
bitcast0 = f32[10,20] bitcast(gte0)
bitcast1 = f32[10,10] bitcast(gte1)
ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
}
)";

TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0},
/*param_number=*/0,
/*param_index=*/{}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/3,
/*param_index=*/{}));

InsertCopies(module.get());

// There should be no copies inserted.
EXPECT_EQ(CountCopies(*module), 0);
}

TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
const string& hlo_string = R"(
HloModule TestModule

@ -1405,6 +1405,8 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_headers",
"//tensorflow/stream_executor/cuda:cuda_diagnostics",
"//tensorflow/stream_executor/gpu:asm_compiler",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
)

@ -174,14 +174,6 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) {
return false;
}

// We can emit DUS in-place; horizontally fusing it makes the emitter no
// longer recognize that it can be done in-place. This creates much slower
// code. This restriction could be lifted if buffer assignment recognized
// that the DUS can be done in-place even inside a horizontal fusion.
if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
return false;
}

return true;
}

@ -203,6 +195,19 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
return true;
}

// Returns whether any operand of `instr` is a parameter instruction that
// is shared with `fusion_instrs`.
bool AnyOpndIsParamSharedAmongFusions(
const HloInstruction* instr,
const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
return opnd->opcode() == HloOpcode::kParameter &&
absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
return user != instr && fusion_instrs.contains(user);
});
});
}

void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
HloInstruction* consumer) {
// First, find out all fusion instructions. We will filter out
@ -230,6 +235,14 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
} else if (!HasOnlyRowMajorLayout(*instr)) {
VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
continue;
} else if (AnyOpndIsParamSharedAmongFusions(instr, fusion_instrs)) {
// Don't fuse fusions whose operands are parameter instructions shared with
// other fusions: because of the concat insertion, we could not i/o-alias
// the resulting horizontal fusion.
VLOG(2) << "Reject the fusion instr because it shares a parameter with"
<< " other fusion candidates, instr: " << instr->ToString();
continue;
} else {
VLOG(2) << "Find a fusion candidate " << instr->ToString();
fusion_instrs_.push_back(instr);
@ -364,33 +364,33 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}

TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule NegativeTestForDynamicUpdateSlice

fusion.1 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}

fusion.2 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}

ENTRY entry {
p.00 = f16[5,9,10]{2,1,0} parameter(0)
p.01 = f16[5,9,10]{2,1,0} parameter(1)
p.10 = s32[1]{0} parameter(2)
p.11 = s32[1]{0} parameter(3)
p.10 = s32[] parameter(2)
p.11 = s32[] parameter(3)
p.20 = f16[1,9,10]{2,1,0} parameter(4)
p.21 = f16[1,9,10]{2,1,0} parameter(5)

@ -400,6 +400,46 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
})")
.ValueOrDie();

EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());

VLOG(2) << "Dump after horizontal fusion:";
VLOG(2) << module->ToString();

EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}

TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule BasicTest

fused_computation.1 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
}

fused_computation.2 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
}

ENTRY entry_computation {
arg.1 = f16[123]{0} parameter(0)
// arg.2 is shared by fusion.1 and fusion.2
arg.2 = f16[123]{0} parameter(1)
arg.3 = f16[123]{0} parameter(2)
fusion.1 = f16[123]{0}
fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
fusion.2 = f16[123]{0}
fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
tuple(fusion.1, fusion.2)
}
)")
.ValueOrDie();

EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}

@ -299,6 +299,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
(f64_atomic_add_supported && element_type == F64);
if (atomic_add_supported) {
AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -307,6 +308,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
if (is_atomic_integral) {
// integral + integral
AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -318,7 +320,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::UMax;
AtomicRMW(opcode, output_address, source,
AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -328,7 +330,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::UMin;
AtomicRMW(opcode, output_address, source,
AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -476,10 +478,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// Emit code to perform the atomicCAS operation
// (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
// cas_new_output);
llvm::Value* ret_value =
AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
llvm::AtomicOrdering::SequentiallyConsistent,
llvm::AtomicOrdering::SequentiallyConsistent);
llvm::Value* ret_value = AtomicCmpXchg(
atomic_memory_address, cas_old_output, cas_new_output, llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent,
llvm::AtomicOrdering::SequentiallyConsistent);

// Extract the memory value returned from atomicCAS and store it as
// cas_old_output.

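The added llvm::MaybeAlign() arguments track newer LLVM IRBuilder APIs, whose atomic helpers take an explicit alignment; a default-constructed MaybeAlign defers to the natural alignment for the type. A standalone sketch of the underlying builder call (hypothetical wrapper; assumes an LLVM version whose CreateAtomicRMW has the alignment-taking overload):

#include "llvm/IR/IRBuilder.h"

// Hypothetical illustration, not part of this change.
llvm::Value* EmitAtomicAddForIllustration(llvm::IRBuilder<>* b,
                                          llvm::Value* output_address,
                                          llvm::Value* source) {
  return b->CreateAtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
                            llvm::MaybeAlign(),
                            llvm::AtomicOrdering::SequentiallyConsistent);
}
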
@ -20,6 +20,8 @@ limitations under the License.
#include <fstream>

#include "absl/base/call_once.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"
@ -202,7 +204,7 @@ absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
// Try to load ptx from files defined in the FLAGS. If successful, return true.
bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
const HloModule* module, std::string* ptx) {
// If the xla_gpu_ptx_file options is set, be explicit when a file is used
// If the xla_gpu_ptx_file option is set, be explicit if a file is used
// and warn when a file is not used, to ease catching typos in the filename.
std::string prefix = xla::FilenameFor(*module, "", *ptx);
std::string matched_filename;
@ -234,6 +236,50 @@ bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
return false;
}

// Try to load textual LLVM IR from files defined in the FLAGS. If
// successful, return the llvm::Module, otherwise return nullptr.
std::unique_ptr<llvm::Module> MaybeLoadLLVMFromFile(const HloModule* module,
llvm::Module* llvm_module) {
// If the xla_gpu_llvm_ir_file option is set, be explicit if a file is used
// and warn when a file is not used, to ease catching typos in the filename.
if (module == nullptr) {
return nullptr;
}

std::string prefix = xla::FilenameFor(*module, "", "");
auto xla_gpu_llvm_ir_file =
module->config().debug_options().xla_gpu_llvm_ir_file();
auto matched_filename = absl::c_find_if(
xla_gpu_llvm_ir_file, [prefix](const string& full_filename) {
// To ease comparing many LLVM versions, accept suffixes different from
// the original filename.
return absl::StartsWith(tensorflow::io::Basename(full_filename),
prefix);
});
if (!xla_gpu_llvm_ir_file.empty() &&
matched_filename == std::end(xla_gpu_llvm_ir_file)) {
VLOG(0) << "RunBackend() - For module with prefix '" << prefix
<< "', we did not find an LLVM file to load.";
}

if (matched_filename != std::end(xla_gpu_llvm_ir_file)) {
VLOG(0) << "RunBackend() - Will load LLVM from file: " << *matched_filename;
llvm::LLVMContext& context = llvm_module->getContext();
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> loaded_module =
llvm::parseIRFile(*matched_filename, err, context);

if (!loaded_module) {
err.print("ERR", llvm::errs());
LOG(FATAL) << "Failed to load an LLVM file. It is probably invalid LLVM.";
}
// Overwrite the dumped non-optimized LLVM so the dump shows which module
// will actually be used.
llvm_ir::DumpIrIfEnabled(*module, *loaded_module, /*optimized=*/false);
return loaded_module;
}
return nullptr;
}
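To exercise this path, the file list has to reach the module's DebugOptions. A hedged sketch (hypothetical helper; assumes only the HloModuleConfig debug-options accessors and the proto-generated add_xla_gpu_llvm_ir_file() for the repeated field introduced further below):

// Hypothetical helper, not part of this change.
void PointConfigAtLlvmIrFile(xla::HloModuleConfig* config,
                             const std::string& path) {
  xla::DebugOptions options = config->debug_options();
  options.add_xla_gpu_llvm_ir_file(path);  // Repeated field 150 in the proto.
  config->set_debug_options(options);
}
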

} // namespace

// Prints a warning if the ptx->sass JIT in the driver has known bugs.
@ -320,13 +366,21 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
libdevice_dir = cached_libdevice_dir_;
}
VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n";
std::unique_ptr<llvm::Module> loaded_module =
MaybeLoadLLVMFromFile(debug_module, llvm_module);
llvm::Module* selected_module = nullptr;
if (loaded_module) {
selected_module = loaded_module.get();
} else {
selected_module = llvm_module;
}

string ptx;
if (!(debug_module &&
MaybeLoadPtxFromFile(module_config, debug_module, &ptx))) {
XLA_SCOPED_LOGGING_TIMER(
"NVPTXCompiler::CompileTargetBinary - CompileToPtx");
TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(llvm_module, gpu_version,
TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(selected_module, gpu_version,
module_config, libdevice_dir));
}

@ -120,6 +120,175 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
return true;
}

namespace {
bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kSlice &&
1 == instr->slice_starts().size() &&
1 == instr->slice_limits().size() &&
1 == instr->slice_strides().size() &&
1 == instr->slice_strides().at(0);
}

bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
if (!unnested_hlo.IsInputFusion()) {
return false;
}
const HloInstruction* root = unnested_hlo.fused_expression_root();
if (root->opcode() != HloOpcode::kTuple) {
return false;
}
return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
return Is1dSliceWithoutStrides(instr);
});
}

struct ConcatUsageInfo {
// Pointer to a previously seen concat. nullptr if no previously seen concat.
const HloInstruction* prev_concat;
// The opnd id of the seen concat.
int64 concat_opnd_idx;
// The slice that recovers the opnd in the concat outputs.
const HloInstruction* slice_to_recover_opnd;
};

// Returns an optional concat usage info to denote whether the concat is used
// in an elementwise manner. A concat followed by slices is considered
// effectively elementwise if the slices, taken together, are the inverse of
// the concat.
absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
const HloInstruction& concat, const HloInstruction& operand,
const ConcatUsageInfo& info) {
// First, check whether this concat is in the pattern below. Also, check
// that the slices, taken together, are in effect the inverse of the concat.
//
// Concat
// | |
// v v
// Slice Slice
//
std::vector<HloInstruction*> users = concat.users();
if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
// Limit our supported cases to 1 dimensional slices.
return absl::optional<ConcatUsageInfo>();
}
// Verify that each operand to the concat is reversed by a slice.
if (users.size() != concat.operand_count() ||
concat.operand_count() != concat.unique_operands().size()) {
return absl::optional<ConcatUsageInfo>();
}
absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
return a->slice_starts().at(0) < b->slice_starts().at(0);
});
int64 prev_limit = 0;
for (int64 i = 0; i < users.size(); ++i) {
const HloInstruction* u = users[i];
int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
if (u->slice_starts().at(0) != prev_limit ||
slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
return absl::optional<ConcatUsageInfo>();
}
prev_limit = u->slice_limits().at(0);
}

// If we have seen other concats, make sure they are identical. Multiple
// concats exist because horizontal fusion inserts one concat for each output
// of the fusion candidates. Check that all concats and operand ids are the
// same to know that the "transitive use closure" will be computed in the same
// iteration space.
int64 operand_idx = concat.operand_index(&operand);
if (info.prev_concat != nullptr) {
bool is_concat_identical = info.prev_concat->Identical(
concat,
/*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
// Operands don't need to be the same.
return true;
});
if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
return absl::optional<ConcatUsageInfo>();
}
}

const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
return absl::optional<ConcatUsageInfo>(
ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
}

// Returns whether we can prove the transitive uses of `param` are in effect
// elementwise. In other words, we prove that the "transitive use closure" will
// all be computed in the same iteration space without any reorder of elements.
// In addition, we check that the "transitive use closure" includes the output
// in the `root_tuple`.
// Theoretically, we can prove more patterns, but our primary use case is
// SliceInputFusion.
bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
const HloInstruction* root_tuple,
const ShapeIndex& out_shape_idx) {
CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
CHECK_EQ(out_shape_idx.size(), 1);
absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(param);
ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
bool is_output_reachable = false;
while (!stack.empty()) {
const HloInstruction* current = stack.back();
stack.pop_back();
visited.insert(current);
for (const HloInstruction* user : current->users()) {
VLOG(3) << "Visiting: " << user->ToString();
switch (user->opcode()) {
case HloOpcode::kTuple:
if (user == root_tuple &&
current == root_tuple->operand(out_shape_idx.back())) {
// We need to know if the output is reachable by the `param` to make
// sure that they will be computed in the same iteration space.
is_output_reachable = true;
}
break;
case HloOpcode::kReshape:
if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
return false;
}
break;
case HloOpcode::kConcatenate: {
absl::optional<ConcatUsageInfo> optional_concat_info =
ConcatIsEffectivelyElementwise(*user, *current,
concat_usage_info);
if (!optional_concat_info) {
return false;
}
concat_usage_info = *optional_concat_info;
// Early continue as we only want to traverse through the slice that
// recovers the operand. It is guaranteed that the operand to the
// concat and the slice have the same iteration space. Insert the
// slice instead of the concat.
CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
stack.push_back(concat_usage_info.slice_to_recover_opnd);
continue;
}
default:
for (const int64 use_index : user->OperandIndices(current)) {
if (!user->IsElementwiseOnOperand(use_index)) {
// Found a user that is non-elementwise on the current
// instruction.
return false;
}
}
if (!LayoutUtil::Equal(current->shape().layout(),
user->shape().layout())) {
// Make sure the layout is not changed by the elementwise op.
return false;
}
break;
} // end of switch
if (!visited.contains(user)) {
stack.push_back(user);
}
}
}
return is_output_reachable;
}
} // namespace

bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
const HloValueSet& value_set = GetValueSet(instruction, index);
@ -1266,10 +1435,23 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
if (operand->opcode() == HloOpcode::kConstant) {
return false;
}

const Shape& operand_subshape =
ShapeUtil::GetSubshape(operand->shape(), operand_index);
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
if (IsSliceInputFusion(*user)) {
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
// We only require the same number of elements and the same element type
// (so the buffer sizes match), not the same dimensions.
return operand_subshape.IsArray() && user_subshape.IsArray() &&
ShapeUtil::ElementsIn(operand_subshape) ==
ShapeUtil::ElementsIn(user_subshape) &&
ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
AreTransitiveUsesEffectivelyElementwise(
fusion_param, user->fused_expression_root(), user_index);
}

// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {

@ -2795,5 +2795,150 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
}

TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) {
const char* kModule = R"(
HloModule test

fused_computation {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
add0 = f32[10, 20] add(p0, p1)
sub0 = f32[10, 10] subtract(p2, p3)
reshape0 = f32[200] reshape(add0)
reshape1 = f32[100] reshape(sub0)
concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
slice0 = f32[200] slice(concat0), slice={[0:200]}
slice1 = f32[100] slice(concat0), slice={[200:300]}
ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}

ENTRY test {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
ROOT fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
auto* fusion = module_->entry_computation()->root_instruction();
auto* param0 = module_->entry_computation()->parameter_instruction(0);
auto* param1 = module_->entry_computation()->parameter_instruction(1);
auto* param2 = module_->entry_computation()->parameter_instruction(2);
auto* param3 = module_->entry_computation()->parameter_instruction(3);

RunAnalysis();
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {0}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {0}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param2, {},
fusion, {1}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param3, {},
fusion, {1}));
// Tensors of different sizes cannot share buffer.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {1}));
}

TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) {
const char* kModule = R"(
HloModule test

fused_computation {
// p0 has multiple transitive uses fed to concat. So, p0 cannot share
// buffer with outputs because the aliased output could be written before
// all the uses of p0 are finished.
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
add0 = f32[100] add(p0, p1)
concat0 = f32[200] concatenate(p0, add0), dimensions={0}
slice0 = f32[100] slice(concat0), slice={[0:100]}
slice1 = f32[100] slice(concat0), slice={[100:200]}
ROOT tuple = (f32[100], f32[100]) tuple(slice0, slice1)
}

ENTRY test {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
ROOT fusion = (f32[100], f32[100]) fusion(p0, p1),
kind=kInput, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
auto* fusion = module_->entry_computation()->root_instruction();
auto* param0 = module_->entry_computation()->parameter_instruction(0);
auto* param1 = module_->entry_computation()->parameter_instruction(1);

RunAnalysis();
// p0 cannot share with either fusion{0} or fusion{1}.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {0}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {1}));
// p1 cannot share with fusion{0} because we're not sure about their
// relationship.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {0}));
// p1 can share with fusion{1} because they will be executed in an
// elementwise manner.
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {1}));
}

TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) {
const char* kModule = R"(
HloModule test

fused_computation {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
add0 = f32[100] add(p0, p1)
sub0 = f32[100] subtract(p1, p1)
concat0 = f32[200] concatenate(p0, add0), dimensions={0}
slice0 = f32[100] slice(concat0), slice={[0:100]}
slice1 = f32[100] slice(concat0), slice={[100:200]}
concat1 = f32[200] concatenate(p0, sub0), dimensions={0}
slice2 = f32[100] slice(concat1), slice={[0:100]}
slice3 = f32[100] slice(concat1), slice={[100:200]}
ROOT tuple = (f32[100], f32[100], f32[100], f32[100])
tuple(slice0, slice1, slice2, slice3)
}

ENTRY test {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
ROOT fusion = (f32[100], f32[100], f32[100], f32[100])
fusion(p0, p1), kind=kInput, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
auto* fusion = module_->entry_computation()->root_instruction();
auto* param0 = module_->entry_computation()->parameter_instruction(0);
auto* param1 = module_->entry_computation()->parameter_instruction(1);

RunAnalysis();
// p0 cannot share.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {0}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {1}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {2}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {3}));
// p1 can share with either fusion{1} or fusion{3}.
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {1}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {3}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {0}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {2}));
}

} // namespace
} // namespace xla

@ -1933,36 +1933,19 @@ Status LayoutAssignment::RunOnComputation(

// Copy the root instruction's result if its layout does not match the result
// layout constraint.
if (constraints.ResultLayout() != nullptr) {
// Layout assignment at this point only does minor-to-major assignment so
// tiling info should be ignored here for comparison.
if (!constraints.ResultLayout()->MatchesLayoutInShape(
computation->root_instruction()->shape(),
/*minor_to_major_only=*/true)) {
if (conditional_mismatch_.count(computation) > 0) {
*FindOrDie(computation_layouts_, computation).mutable_result_layout() =
FindOrDie(conditional_mismatch_, computation).result_layout();
}
TF_ASSIGN_OR_RETURN(
HloInstruction * new_root,
CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
computation->root_instruction()));
computation->set_root_instruction(new_root);
} else {
// Copy the specified tiling info.
auto assign_tiling = [&constraints](xla::Shape* subshape,
const xla::ShapeIndex& index) {
if (subshape->IsArray()) {
const Shape& result_shape = ShapeUtil::GetSubshape(
constraints.ResultLayout()->shape(), index);
subshape->mutable_layout()->mutable_tiles()->assign(
result_shape.layout().tiles().begin(),
result_shape.layout().tiles().end());
}
};
xla::ShapeUtil::ForEachMutableSubshape(
computation->root_instruction()->mutable_shape(), assign_tiling);
if (constraints.ResultLayout() != nullptr &&
!constraints.ResultLayout()->MatchesLayoutInShape(
computation->root_instruction()->shape(),
/*minor_to_major_only=*/true)) {
if (conditional_mismatch_.count(computation) > 0) {
*FindOrDie(computation_layouts_, computation).mutable_result_layout() =
FindOrDie(conditional_mismatch_, computation).result_layout();
}
TF_ASSIGN_OR_RETURN(
HloInstruction * new_root,
CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
computation->root_instruction()));
computation->set_root_instruction(new_root);
}
return Status::OK();
}

@ -441,8 +441,7 @@ absl::flat_hash_map<int, HloInstruction*> CreateSinkedAllReduces(
}
CHECK(ContainsKey(tuple_index_to_old_buffer, tuple_index));
HloInstruction* old_buffer = tuple_index_to_old_buffer.at(tuple_index);
CHECK(Shape::Equal().IgnoreLayout()(old_buffer->shape(),
all_reduced_delta->shape()));
CHECK(ShapeUtil::Equal(old_buffer->shape(), all_reduced_delta->shape()));
HloInstruction* add_to_old_buffer =
while_parent->AddInstruction(HloInstruction::CreateBinary(
all_reduced_delta->shape(), HloOpcode::kAdd, old_buffer,

@ -618,8 +618,8 @@ optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
}

Literal cond_result_pred = std::move(eval_result.ValueOrDie());
CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
ShapeUtil::MakeShape(PRED, {})));
CHECK(ShapeUtil::Equal(cond_result_pred.shape(),
ShapeUtil::MakeShape(PRED, {})));

// Per the explanation above, if the evaluated condition returns false, the
// loop executes at most once.

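Both replacements above swap a layout-insensitive comparison for ShapeUtil::Equal, which also compares layouts. A small illustration with hypothetical shapes:

// Illustration only, not part of this change.
xla::Shape row_major =
    xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {1, 0});
xla::Shape col_major =
    xla::ShapeUtil::MakeShapeWithLayout(xla::F32, {2, 3}, {0, 1});
CHECK(xla::Shape::Equal().IgnoreLayout()(row_major, col_major));  // Layouts ignored.
CHECK(!xla::ShapeUtil::Equal(row_major, col_major));  // Layouts compared.
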
@ -314,7 +314,10 @@ message DebugOptions {
// Compilation errors out if these ops are encountered.
bool xla_gpu_deterministic_ops = 148;

// Next id: 150
// Paths to files with LLVM code.
repeated string xla_gpu_llvm_ir_file = 150;

// Next id: 151

// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.

@ -602,7 +602,6 @@ cc_library(
"//tensorflow/core/kernels:batch_kernels",
"//tensorflow/core/kernels:bincount_op",
"//tensorflow/core/kernels:boosted_trees_ops",
"//tensorflow/core/kernels:tensor_forest_ops",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:checkpoint_ops",
"//tensorflow/core/kernels:clustering_ops",
@ -994,7 +993,6 @@ filegroup(
"stateless_random_ops_v2_op_lib",
"string_ops_op_lib",
"summary_ops_op_lib",
"tensor_forest_ops_op_lib",
"tpu_configuration_ops_op_lib",
"tpu_cross_replica_ops_op_lib",
"tpu_embedding_ops_op_lib",

@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestCreateTreeVariable"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be created.
END
}
in_arg {
name: "tree_config"
description: <<END
Serialized proto string of the boosted_trees.Tree.
END
}
summary: "Creates a tree resource and returns a handle to it."
}
@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeDeserialize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be restored.
END
}
in_arg {
name: "tree_config"
description: <<END
Serialized proto string of the boosted_trees.Tree proto.
END
}
summary: "Deserializes a proto into the tree handle"
}
@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeIsInitializedOp"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree.
END
}
out_arg {
name: "is_initialized"
description: <<END
Whether the tree is initialized.
END
}
summary: "Checks whether a tree has been initialized."
}
@ -1,29 +0,0 @@
op {
graph_op_name: "TensorForestTreePredict"
visibility: HIDDEN
attr {
name: "logits_dimension"
description: <<END
Scalar, dimension of the logits.
END
}
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource.
END
}
in_arg {
name: "dense_features"
description: <<END
Rank 2 dense features tensor.
END
}
out_arg {
name: "logits"
description: <<END
The logits predictions from the tree for each instance in the batch.
END
}
summary: "Output the logits for the given input data"
}
@ -1,5 +0,0 @@
op {
graph_op_name: "TensorForestTreeResourceHandleOp"
visibility: HIDDEN
summary: "Creates a handle to a TensorForestTreeResource"
}
@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeSerialize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be serialized.
END
}
out_arg {
name: "tree_config"
description: <<END
Serialized proto string of the tree resource.
END
}
summary: "Serializes the tree handle to a proto"
}
@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeSize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource.
END
}
out_arg {
name: "tree_size"
description: <<END
The size of the tree.
END
}
summary: "Get the number of nodes in a tree"
}
@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestCreateTreeVariable"
}
@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeDeserialize"
}
@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeIsInitializedOp"
}
@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreePredict"
}
@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeResourceHandleOp"
}
@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeSerialize"
}
@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeSize"
}
@ -191,19 +191,33 @@ Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
type.name());
}

Status ResourceMgr::Lookup(const ResourceHandle& handle,
ResourceBase** resource) const {
tf_shared_lock l(mu_);
return DoLookup(handle.container(), handle.hash_code(),
/*type_name=*/"ResourceBase", handle.name(), resource);
}

Status ResourceMgr::DoLookup(const string& container, TypeIndex type,
const string& name,
ResourceBase** resource) const {
return DoLookup(container, type.hash_code(), type.name(), name, resource);
}

Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code,
const string& type_name,
const string& resource_name,
ResourceBase** resource) const {
const Container* b = gtl::FindPtrOrNull(containers_, container);
if (b == nullptr) {
return errors::NotFound("Container ", container,
" does not exist. (Could not find resource: ",
container, "/", name, ")");
container, "/", resource_name, ")");
}
auto iter = b->find({type.hash_code(), name});
auto iter = b->find({type_hash_code, resource_name});
if (iter == b->end()) {
return errors::NotFound("Resource ", container, "/", name, "/", type.name(),
" does not exist.");
return errors::NotFound("Resource ", container, "/", resource_name, "/",
type_name, " does not exist.");
}
*resource = const_cast<ResourceBase*>(iter->second.resource.get());
(*resource)->Ref();
@ -326,6 +340,12 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
return Status::OK();
}

Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
ResourceBase** value) {
TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
return ctx->resource_manager()->Lookup(p, value);
}

Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
return ctx->resource_manager()->Delete(p);
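A hedged usage sketch for the new type-erased overloads (hypothetical caller; mirrors the LookupDeleteGenericResource test later in this diff):

// Hypothetical caller, not part of this change: works on a handle without
// knowing the resource's concrete C++ type.
void InspectAndDeleteGenericResource(OpKernelContext* ctx,
                                     const ResourceHandle& handle) {
  ResourceBase* base = nullptr;
  if (LookupResource(ctx, handle, &base).ok()) {
    core::ScopedUnref unref(base);  // Lookup handed the caller one ref.
    VLOG(1) << "Resource: " << base->DebugString();
    TF_CHECK_OK(DeleteResource(ctx, handle));
  }
}
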
@ -181,6 +181,13 @@ class ResourceMgr {
Status Lookup(const std::string& container, const std::string& name,
T** resource) const TF_MUST_USE_RESULT;

// If the resource manager has a resource matching "handle", returns it in
// "*resource" and the caller takes the ownership of one ref on "*resource".
//
// REQUIRES: resource != nullptr
Status Lookup(const ResourceHandle& handle,
ResourceBase** resource) const TF_MUST_USE_RESULT;

// Similar to Lookup, but looks up multiple resources at once, with only a
// single lock acquisition. If containers_and_names[i] is uninitialized
// then this function does not modify resources[i].
@ -260,6 +267,9 @@ class ResourceMgr {
Status LookupInternal(const std::string& container, const std::string& name,
T** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
Status LookupInternal(const std::string& container, uint64 type_hash_code,
const std::string& name, ResourceBase** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;

Status DoCreate(const std::string& container, TypeIndex type,
const std::string& name, ResourceBase* resource)
@ -268,6 +278,11 @@ class ResourceMgr {
Status DoLookup(const std::string& container, TypeIndex type,
const std::string& name, ResourceBase** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
Status DoLookup(const std::string& container, uint64 type_hash_code,
const std::string& type_name,
const std::string& resource_name,
ResourceBase** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;

Status DoDelete(const std::string& container, uint64 type_hash_code,
const std::string& resource_name,
@ -733,12 +748,17 @@ Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {

} // namespace internal

// Creates the resource pointed at by "p". The caller transfers the ownership of
// one ref on "*value" to the resource manager in "ctx", regardless of whether
// this operation succeeds or fails.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
return ctx->resource_manager()->Create(p.container(), p.name(), value);
}

// If the resource manager in "ctx" has a resource matching "p", returns it in
// "*value" and the caller takes the ownership of one ref on "*value"
template <typename T, bool use_dynamic_cast>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value) {
@ -747,6 +767,13 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
p.name(), value);
}

// If the resource manager in "ctx" has a resource matching "p", returns it in
// "*value" and the caller takes the ownership of one ref on "*value"
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
ResourceBase** value);

// If the resource manager in "ctx" has a resource matching "p", returns it in
// "*value".
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value) {
@ -757,6 +784,8 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
return Status::OK();
}

// Similar to Lookup, but looks up multiple resources at once, with only a
// single lock acquisition.
template <typename T>
Status LookupResources(OpKernelContext* ctx,
absl::Span<ResourceHandle const* const> p,
@ -770,6 +799,13 @@ Status LookupResources(OpKernelContext* ctx,
return ctx->resource_manager()->LookupMany(containers_and_names, values);
}

// If the resource manager in "ctx" has a resource pointed at by "p", returns
// it in "*value". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*value".
//
// WARNING: creator() must not call any methods on the resource manager during
// its execution, because a non-reentrant lock is held during the creator() call
// in order to guarantee atomicity of LookupOrCreateResource().
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator) {
@ -778,6 +814,12 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
creator);
}

// If the resource manager in "ctx" has a resource pointed at by "p", returns
// it in "*value". Otherwise, invokes creator() to create the resource.
//
// WARNING: creator() must not call any methods on the resource manager during
// its execution, because a non-reentrant lock is held during the creator() call
// in order to guarantee atomicity of LookupOrCreateResource().
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value,
@ -789,12 +831,14 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
return Status::OK();
}

// Deletes the resource pointed by "p", using the resource manager in "ctx".
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
return ctx->resource_manager()->Delete<T>(p.container(), p.name());
}

// Deletes the resource pointed by "p", using the resource manager in "ctx".
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);

template <typename T>

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"

@ -285,6 +286,36 @@ TEST(ResourceHandleTest, CRUD) {
}
}

TEST(ResourceHandleTest, LookupDeleteGenericResource) {
ResourceMgr resource_mgr("");
OpKernelContext::Params params;
params.resource_manager = &resource_mgr;
StubDevice device("device_name");
params.device = &device;
OpKernelContext ctx(&params, 0);

ResourceHandle p =
MakeResourceHandle<StubResource>(&ctx, "container", "name");

{
auto* r = new StubResource();
r->value_ = 42;
TF_EXPECT_OK(CreateResource(&ctx, p, r));
}
{
ResourceBase* r;
TF_ASSERT_OK(LookupResource(&ctx, p, &r));
ASSERT_TRUE(r != nullptr);
core::ScopedUnref unref(r);
EXPECT_EQ(static_cast<StubResource*>(r)->value_, 42);
}
{
TF_EXPECT_OK(DeleteResource(&ctx, p));
ResourceBase* unused;
EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok());
}
}

TEST(ResourceHandleTest, DifferentDevice) {
ResourceMgr resource_mgr("");
OpKernelContext::Params params;

@ -7240,13 +7240,6 @@ tf_kernel_library(
],
)

tf_kernel_library(
name = "tensor_forest_ops",
deps = [
"//tensorflow/core/kernels/tensor_forest:tensor_forest_ops",
],
)

tf_kernel_library(
name = "data_service_ops",
deps = [

@ -1,141 +0,0 @@
# Description:
# quantization-specific OpKernels for hexagon

load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")

package(
default_visibility = ["//visibility:public"],
licenses = ["notice"],  # Apache 2.0
)

tf_cc_test(
name = "graph_transferer_test",
size = "small",
srcs = [
"graph_transferer_test.cc",
"hexagon_graph_execution_test.cc",
],
data = ["//tensorflow/core/example:example_parser_configuration_testdata"],
deps = [
":graph_transferer",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:bitwise_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:mkl_nn_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:remote_fused_graph_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:quantization_utils",
"//tensorflow/core/kernels:quantized_ops",
"//tensorflow/core/kernels:reduction_ops",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/core/kernels:remote_fused_graph_ops",
"//tensorflow/core/kernels:reshape_op",
"//tensorflow/core/kernels:softmax_op",
"@com_google_absl//absl/base",
],
)

tf_kernel_library(
name = "graph_transferer",
srcs = [
"graph_transfer_utils.cc",
"graph_transferer.cc",
"hexagon_control_wrapper.cc",
"hexagon_ops_definitions.cc",
"soc_interface.cc",
],
hdrs = [
"graph_transfer_utils.h",
"graph_transferer.h",
"hexagon_control_wrapper.h",
"hexagon_ops_definitions.h",
"soc_interface.h",
],
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:remote_fused_graph_ops",
"//tensorflow/cc:scope",
"//tensorflow/core",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//third_party/eigen3",
],
)

cc_library(
name = "hexagon_rewriter_transform",
srcs = [
"hexagon_rewriter_transform.cc",
],
deps = [
":graph_transferer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:remote_fused_graph_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
],
alwayslink = 1,
)

tf_cc_test(
name = "hexagon_rewriter_transform_test",
size = "small",
srcs = ["hexagon_rewriter_transform_test.cc"],
deps = [
":graph_transferer",
":hexagon_rewriter_transform",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/tools/graph_transforms:transform_utils",
],
)

cc_library(
name = "hexagon_remote_fused_graph_executor_build",
srcs = [
"hexagon_remote_fused_graph_executor_build.cc",
],
deps = [
":graph_transferer",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
],
alwayslink = 1,
)

tf_cc_test(
name = "hexagon_remote_fused_graph_executor_build_test",
size = "small",
srcs = ["hexagon_remote_fused_graph_executor_build_test.cc"],
deps = [
":hexagon_remote_fused_graph_executor_build",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
],
)
@@ -1,167 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"

#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/logging.h"

namespace tensorflow {

// function alias
constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
    &RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;

/* static */ std::priority_queue<std::tuple<float, int, string>>
GraphTransferUtils::GetTopNFloatResults(const float* const data,
                                        const string* const labels,
                                        const int element_count) {
  CHECK(data != nullptr);
  CHECK(labels != nullptr);
  std::priority_queue<std::tuple<float, int, string>> queue;
  for (int i = 0; i < element_count; ++i) {
    queue.emplace(data[i], i, labels[i]);
  }
  return queue;
}

/* static */ void GraphTransferUtils::DumpTopNFloatResults(
    const float* const data, const string* const labels,
    const int element_count, const int top_n) {
  std::priority_queue<std::tuple<float, int, string>> queue =
      GetTopNFloatResults(data, labels, element_count);
  LOG(INFO) << "=== Dump ranking ===";
  for (int i = 0; i < top_n; ++i) {
    const std::tuple<float, int, string>& entry = queue.top();
    LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry)
              << ", " << std::get<0>(entry);
    queue.pop();
  }
}

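The heap above is the whole top-N trick: std::tuple comparison is lexicographic, so a max-heap of (score, index, label) tuples always exposes the highest score at top(). A minimal self-contained sketch of the same pattern, with made-up scores and labels (hypothetical data; plain std::string instead of TF's string alias):

#include <iostream>
#include <queue>
#include <string>
#include <tuple>

int main() {
  const float scores[] = {0.1f, 0.7f, 0.2f};
  const std::string labels[] = {"cat", "dog", "bird"};
  // Tuples compare lexicographically, so the heap is ordered by score first.
  std::priority_queue<std::tuple<float, int, std::string>> queue;
  for (int i = 0; i < 3; ++i) {
    queue.emplace(scores[i], i, labels[i]);
  }
  // Pops in descending score order: dog (0.7), then bird (0.2).
  for (int n = 0; n < 2 && !queue.empty(); ++n) {
    std::cout << std::get<2>(queue.top()) << ": " << std::get<0>(queue.top())
              << "\n";
    queue.pop();
  }
  return 0;
}
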
/* static */ RemoteFusedGraphExecuteInfo
GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
    const GraphDef& graph_def,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& outputs,
    const RemoteFusedGraphExecuteUtils::TensorShapeMap& tensor_shape_map) {
  RemoteFusedGraphExecuteInfo execute_info;
  execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");

  // Copy graph
  *execute_info.mutable_remote_graph() = graph_def;

  for (const std::pair<string, Tensor>& input : inputs) {
    execute_info.add_graph_input_node_name(input.first);
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
        *execute_info.add_default_graph_input_tensor_shape();
    tensor_shape_type.set_dtype(input.second.dtype());
    TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
    for (const int64 dim : input.second.shape().dim_sizes()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

  for (const string& output_name : outputs) {
    const std::pair<DataType, TensorShape>* tensor_shape_type =
        RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
                                                         output_name);
    CHECK_NOTNULL(tensor_shape_type);
    execute_info.add_graph_output_node_name(output_name);
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type_proto =
        *execute_info.add_default_graph_output_tensor_shape();
    tensor_shape_type_proto.set_dtype(tensor_shape_type->first);
    TensorShapeProto& tensor_shape_proto =
        *tensor_shape_type_proto.mutable_shape();
    for (const int64 dim : tensor_shape_type->second.dim_sizes()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

  return execute_info;
}

/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
    const IRemoteFusedGraphOpsDefinitions& ops_definitions,
    const string& remote_graph_execute_name,
    const std::vector<std::pair<string, Tensor>>& inputs,
    const std::vector<string>& outputs, GraphDef* original_def) {
  RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
  Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
      *original_def, inputs, true /* initialize_by_zero */, &tensor_shape_map);
  for (NodeDef& node_def : *original_def->mutable_node()) {
    TF_CHECK_OK(
        AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
  }
  CHECK(status.ok());

  Scope root = Scope::NewRootScope();
  std::vector<Output> output_list;
  DataTypeVector input_types;
  for (const std::pair<string, Tensor>& input_node_info : inputs) {
    const Scope& scope = root.WithOpName(input_node_info.first);
    Node* ret;
    const auto unique_name = scope.GetUniqueNameForOp("Placeholder");
    auto builder = NodeBuilder(unique_name, "Placeholder")
                       .Attr("dtype", input_node_info.second.dtype())
                       .Attr("shape", input_node_info.second.shape());
    scope.UpdateBuilder(&builder);
    scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
    TF_CHECK_OK(scope.status());
    output_list.emplace_back(Output(ret, 0));
    input_types.push_back(input_node_info.second.dtype());
  }

  const RemoteFusedGraphExecuteInfo execute_info =
      BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs,
                                       tensor_shape_map);

  DataTypeVector output_types;
  // Sanity-check to confirm all output data types are the same.
  for (const string& output_node_name : outputs) {
    const std::pair<DataType, TensorShape>* tst =
        RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
                                                         output_node_name);
    CHECK_NE(tst, nullptr);
    output_types.push_back(tst->first);
  }

  const Scope& scope = root.WithOpName(remote_graph_execute_name);
  CHECK(scope.ok());
  auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
  Node* node;
  const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");

  auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
                     .Input(node_out_list)
                     .Attr("Tinputs", input_types)
                     .Attr("Toutputs", output_types)
                     .Attr("serialized_remote_fused_graph_execute_info",
                           StringPiece(execute_info.SerializeAsString()));
  CHECK(scope.ok());
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &node));
  CHECK(scope.ok()) << scope.status();

  GraphDef fusedGraphDef;
  TF_CHECK_OK(root.ToGraphDef(&fusedGraphDef));
  return fusedGraphDef;
}

}  // namespace tensorflow
@@ -1,59 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_

#include <queue>
#include <utility>
#include <vector>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

class RemoteFusedGraphExecuteInfo;

class GraphTransferUtils {
 public:
  static std::priority_queue<std::tuple<float, int, string>>
  GetTopNFloatResults(const float* const data, const string* const labels,
                      const int element_count);

  static void DumpTopNFloatResults(const float* const data,
                                   const string* const labels,
                                   const int element_count, const int top_n);

  static GraphDef BuildFusedGraphDef(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const string& remote_graph_execute_name,
      const std::vector<std::pair<string, Tensor>>& inputs,
      const std::vector<string>& outputs, GraphDef* original_def);

 private:
  static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
      const GraphDef& graph_def,
      const std::vector<std::pair<string, Tensor>>& inputs,
      const std::vector<string>& outputs,
      const RemoteFusedGraphExecuteUtils::TensorShapeMap& tensor_shape_map);

  TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferUtils);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
File diff suppressed because it is too large
@@ -1,231 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_

#include <array>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/padding.h"

namespace tensorflow {

class GraphTransferInfo;
class GraphTransferNodeInfo;
class GraphTransferNodeInputInfo;

// GraphTransferer transfers graph definitions into SoC memory.
// This functionality is effective if the SoC is capable of running
// the graph on that chip.
// TODO(satok): support transferring subgraphs to be able to split graphs
// to avoid unsupported ops in SoC.
class GraphTransferer {
 public:
  // TODO(satok): Remove. Use proto definition instead.
  static constexpr int MAX_SUPPORTED_RANK = 4;
  // TODO(satok): Remove. Use proto definition instead.
  static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
  using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;

  GraphTransferer();

  ~GraphTransferer();

  // Load graph structure into GraphTransferer
  // TODO(satok): Pass a pair of TensorShape and DataType instead of
  // Tensor as input_node_info_list.
  Status LoadGraphFromProto(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const GraphDef& graph_def,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names,
      const bool shape_inference_for_unknown_shape);

  // Load graph structure into GraphTransferer from protobuf file
  // TODO(satok): Pass a pair of TensorShape and DataType instead of
  // Tensor as input_node_info_list.
  Status LoadGraphFromProtoFile(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const string& graph_def_path,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names, const bool is_text_proto,
      const bool shape_inference_for_unknown_shape,
      const bool dry_run_for_unknown_shape);

  // Sort params so that all input nodes appear before consumer nodes.
  // CAVEAT: This may be slow if the number of nodes is too large
  void SortParams(const std::vector<string>& output_node_names);

  void EnableStrictCheckMode(bool enable);

  // Import parameters for transfer
  void SetSerializedGraphTransferInfo(const string& serialized_proto);

  // Return parameters for graph transfer
  const GraphTransferInfo& GetGraphTransferInfo() const;

  // Return mutable GraphTransferInfo for graph transfer
  GraphTransferInfo& GetMutableGraphTransferInfo();

  // Dump verification string of parameters to verify with offline tools
  void DumpVerificationStringOfNodeTransferParams() const;

  static std::array<int64, SHAPE_ARRAY_SIZE> ToTensorShapeArray(
      const TensorShape& shape);

 private:
  class TransferParamsComparator {
   public:
    TransferParamsComparator(
        const std::unordered_map<int, std::unordered_set<int>>& dep_map);
    bool operator()(const GraphTransferNodeInfo& obj0,
                    const GraphTransferNodeInfo& obj1);
    const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
  };

  void CacheNode(const Node& node);

  bool AreAllInputsCached(const Node& node) const;

  // Transform a remote fused graph to add an aggregated input node which takes
  // all inputs of the remote graph.
  Status TransformGraphToAddAggregatedInputNode(
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      Graph* graph, ShapeRefiner* shape_refiner);

  Status RegisterNode(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names);

  void RegisterConstantNode(const ShapeRefiner& shape_refiner,
                            const Node& node);

  int RegisterConstantShape(const std::vector<int>& shape);

  int RegisterConstTensor(const Tensor& tensor, const string& suffix);

  int RegisterConstScalar(const DataType dt, const int val, const int dst_id,
                          const int dst_input_count);

  bool HasPaddingAndStrides(const Node& node);

  bool NeedsToAddRank(const Node& node);

  bool IsPadNode(const Node& node);

  // Return true if the node is a reshape op which just flattens input
  // TODO(satok): Remove this method once generic reshape op is implemented in
  // SOC
  bool IsNodeFlattenReshape(const Node& node,
                            const ShapeRefiner& shape_refiner);

  void RegisterNodeWithPaddingAndStrides(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterNodeWithRank(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
                       const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
                         const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterFlattenNode(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  void RegisterGenericNode(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node);

  Status RegisterNodeIfAllInputsAreCached(
      const IRemoteFusedGraphOpsDefinitions& ops_definitions,
      const ShapeRefiner& shape_refiner, const Node& node,
      const bool only_register_const_node,
      const std::vector<std::pair<string, Tensor>>& input_node_info_list,
      const std::vector<string>& output_node_names);

  void AppendNodeParams(const string& name, const int id, const string& type,
                        const int type_id, const int padding,
                        const int inputs_size,
                        const std::vector<int>& extra_inputs,
                        const int outputs_size);

  void AddNodeInputByInputIndex(const Node& node, const int idx,
                                GraphTransferNodeInputInfo* node_input_info);

  void AppendNodeInputParams(const int id, const Node& node,
                             const std::vector<int>& extra_inputs);

  void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id,
                              const Node& node);

  static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
      const shape_inference::ShapeHandle& shape_handle,
      shape_inference::InferenceContext* context);

  void AppendNodeParamsWithIoParams(
      const ShapeRefiner& shape_refiner, const Node& node, const string& name,
      const int id, const string& type, const int type_id, const int padding,
      const int inputs_size, const std::vector<int>& extra_inputs,
      const int outputs_size, const bool append_input_params,
      const bool append_output_params);

  static string ToPaddingDebugString(int padding);

  // Create dependency map
  static void FillDependencyRec(
      int node_id, std::unordered_map<int, std::unordered_set<int>>& dep_map,
      std::unordered_set<int>& completed);

  // Build tensor from proto
  static Status MakeTensorFromProto(const TensorProto& tensor_proto,
                                    Tensor* tensor);

  void ClearCache();

  // Dump pretty print of parameters
  void DumpNodeTransferParams() const;

  GraphTransferInfo* graph_transfer_info_;

  std::vector<const Node*> node_name_cache_list_{};
  std::unordered_map<string, int> node_name_to_id_cache_map_{};

  // Strict check mode is true by default. Disable this if the ops' shape
  // inferences are not implemented correctly.
  bool strict_check_mode_{true};

  TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferer);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
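Taken together, the declarations above suggest a load-then-export flow. The following is a minimal sketch of driving GraphTransferer, not code from this commit; ops_definitions and graph_def are assumed to exist, and the input/output names are hypothetical:

GraphTransferer gt;
gt.EnableStrictCheckMode(false);  // Relax checks if shape inference is incomplete.
Status status = gt.LoadGraphFromProto(
    ops_definitions, graph_def,
    {{"input", Tensor{DT_FLOAT, {1, 1, 1, 1}}}},  // input node name and shape hint
    {"softmax"},                                  // output node names
    false /* shape_inference_for_unknown_shape */);
if (status.ok()) {
  // The collected parameters can be serialized, shipped to the SoC side,
  // and re-imported later via SetSerializedGraphTransferInfo().
  const string serialized = gt.GetGraphTransferInfo().SerializeAsString();
}
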
@@ -1,481 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

const string NAME_A = "a";
const string NAME_B = "b";
const string NAME_A_PLUS_B = "a_plus_b";
constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_B_VAL = 3.0f;
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;

class GraphTransfererTest : public ::testing::Test {
 protected:
  void SetUp() final {}

  GraphTransferer gt_;
};

const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;

class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
 public:
  int GetTotalOpsCount() const final { return op_types_.size(); }

  int GetOpIdFor(const string& op_type, const DataTypeVector&) const final {
    for (int i = 0; i < op_types_.size(); ++i) {
      if (op_types_[i] == op_type) {
        return i;
      }
    }
    return -1;
  }

 private:
  const std::vector<string> op_types_{"INPUT",   "OUTPUT",  "Conv2D",
                                      "MaxPool", "NoOp",    "Add",
                                      "Const",   "Softmax", "Identity"};
} TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;

static Output BuildAddOps(const Scope& scope, const Input& x, const Input& y) {
  EXPECT_TRUE(scope.ok());
  auto _x = ops::AsNodeOut(scope, x);
  EXPECT_TRUE(scope.ok());
  auto _y = ops::AsNodeOut(scope, y);
  EXPECT_TRUE(scope.ok());
  Node* ret;
  const auto unique_name = scope.GetUniqueNameForOp("Add");
  auto builder = NodeBuilder(unique_name, "Add").Input(_x).Input(_y);
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  EXPECT_TRUE(scope.ok());
  return Output(ret, 0);
}

static Output BuildSoftmaxOps(const Scope& scope, const Input& logits) {
  EXPECT_TRUE(scope.ok());
  auto _logits = ops::AsNodeOut(scope, logits);
  EXPECT_TRUE(scope.ok());
  Node* ret;
  const auto unique_name = scope.GetUniqueNameForOp("Softmax");
  auto builder = NodeBuilder(unique_name, "Softmax").Input(_logits);
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  EXPECT_TRUE(scope.ok());
  return Output(ret, 0);
}

static Output BuildConv2DOps(const Scope& scope, const Input& input,
                             const Input& filter,
                             const gtl::ArraySlice<int>& strides,
                             const StringPiece& padding) {
  EXPECT_TRUE(scope.ok());
  auto _input = ops::AsNodeOut(scope, input);
  EXPECT_TRUE(scope.ok());
  auto _filter = ops::AsNodeOut(scope, filter);
  EXPECT_TRUE(scope.ok());
  Node* ret;
  const auto unique_name = scope.GetUniqueNameForOp("Conv2D");
  auto builder = NodeBuilder(unique_name, "Conv2D")
                     .Input(_input)
                     .Input(_filter)
                     .Attr("strides", strides)
                     .Attr("use_cudnn_on_gpu", true)
                     .Attr("padding", padding)
                     .Attr("data_format", "NHWC");
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  EXPECT_TRUE(scope.ok());
  return Output(ret, 0);
}

static Output BuildMaxPoolOps(const Scope& scope, const Input& input,
                              const gtl::ArraySlice<int>& ksize,
                              const gtl::ArraySlice<int>& strides,
                              const StringPiece& padding) {
  EXPECT_TRUE(scope.ok());
  auto _input = ops::AsNodeOut(scope, input);
  EXPECT_TRUE(scope.ok());
  Node* ret;
  const auto unique_name = scope.GetUniqueNameForOp("MaxPool");
  auto builder = NodeBuilder(unique_name, "MaxPool")
                     .Input(_input)
                     .Attr("ksize", ksize)
                     .Attr("strides", strides)
                     .Attr("padding", padding)
                     .Attr("data_format", "NHWC");
  scope.UpdateBuilder(&builder);
  scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
  EXPECT_TRUE(scope.ok());
  return Output(ret, 0);
}

static GraphDef CreateAddGraphDef() {
  Scope root = Scope::NewRootScope();
  Output node_a = ops::Const(root.WithOpName(NAME_A), NODE_A_VAL);
  Output node_b = ops::Const(root.WithOpName(NAME_B), NODE_B_VAL);
  Output node_add = BuildAddOps(root.WithOpName(NAME_A_PLUS_B), node_a, node_b);
  GraphDef def;
  TF_CHECK_OK(root.ToGraphDef(&def));
  return def;
}

static GraphDef CreateConvGraphDef() {
  Scope root = Scope::NewRootScope();
  Tensor input_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
  test::FillIota<float>(&input_data, 1.0f);
  Output input =
      ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
  Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
  test::FillIota<float>(&filter_data, 1.0f);
  Output filter =
      ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data));
  const std::vector<int> strides{1, 1, 1, 1};
  Output conv =
      BuildConv2DOps(root.WithOpName("conv"), input, filter, strides, "SAME");
  Output softmax = BuildSoftmaxOps(root.WithOpName("softmax"), conv);
  GraphDef def;
  TF_CHECK_OK(root.ToGraphDef(&def));
  return def;
}

static GraphDef CreatePoolGraphDef() {
  Scope root = Scope::NewRootScope();
  Tensor input_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
  test::FillIota<float>(&input_data, 1.0f);
  Output input =
      ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
  Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
  test::FillIota<float>(&filter_data, 1.0f);
  Output filter =
      ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data));
  const std::vector<int> ksize{1, 1, 1, 1};
  const std::vector<int> padding{0, 0, 0, 0};
  const std::vector<int> strides{1, 1, 1, 1};
  Output max_pool = BuildMaxPoolOps(root.WithOpName("maxpool"), input, ksize,
                                    strides, "SAME");
  Output softmax = BuildSoftmaxOps(root.WithOpName("softmax"), max_pool);
  GraphDef def;
  TF_CHECK_OK(root.ToGraphDef(&def));
  return def;
}

static const GraphTransferConstNodeInfo* FindConstNodeInfo(
    const GraphTransferer& gt, const string& name) {
  for (const GraphTransferConstNodeInfo& params :
       gt.GetGraphTransferInfo().const_node_info()) {
    if (params.name() == name) {
      return &params;
    }
  }
  return nullptr;
}

static const GraphTransferNodeInfo* FindNodeInfo(const GraphTransferer& gt,
                                                 const string& name) {
  for (const GraphTransferNodeInfo& params :
       gt.GetGraphTransferInfo().node_info()) {
    if (params.name() == name) {
      return &params;
    }
  }
  return nullptr;
}

static const GraphTransferNodeInputInfo* FindNodeInputInfo(
    const GraphTransferer& gt, const int node_id) {
  for (const GraphTransferNodeInputInfo& params :
       gt.GetGraphTransferInfo().node_input_info()) {
    if (params.node_id() == node_id) {
      return &params;
    }
  }
  return nullptr;
}

static const GraphTransferNodeOutputInfo* FindNodeOutputInfo(
    const GraphTransferer& gt, const int node_id) {
  for (const GraphTransferNodeOutputInfo& params :
       gt.GetGraphTransferInfo().node_output_info()) {
    if (params.node_id() == node_id) {
      return &params;
    }
  }
  return nullptr;
}

static void SanityCheckNodes(const GraphTransferer& gt) {
  for (const GraphTransferNodeInfo& params :
       gt.GetGraphTransferInfo().node_info()) {
    if (params.input_count() > 0) {
      const GraphTransferNodeInputInfo* input_params =
          FindNodeInputInfo(gt, params.node_id());
      ASSERT_NE(nullptr, input_params);
      EXPECT_EQ(params.input_count(), input_params->node_input_size());
      EXPECT_EQ(params.node_id(), input_params->node_id());
      for (const GraphTransferNodeInput& node_input :
           input_params->node_input()) {
        EXPECT_GE(node_input.output_port(), 0);
      }
    }
    if (params.output_count() > 0) {
      const GraphTransferNodeOutputInfo* output_params =
          FindNodeOutputInfo(gt, params.node_id());
      ASSERT_NE(nullptr, output_params);
      EXPECT_EQ(params.output_count(), output_params->max_byte_size_size());
      EXPECT_EQ(params.node_id(), output_params->node_id());
      for (const int max_size : output_params->max_byte_size()) {
        EXPECT_GE(max_size, 0);
      }
    }
  }
}

TEST_F(GraphTransfererTest, LoadAddGraph) {
  GraphDef def = CreateAddGraphDef();
  ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
                                     {}, std::vector<string>{NAME_A_PLUS_B},
                                     false)
                  .ok());
  SanityCheckNodes(gt_);

  const int const_node_count =
      gt_.GetGraphTransferInfo().const_node_info_size();
  ASSERT_EQ(2, const_node_count);
  const GraphTransferConstNodeInfo* params_a = FindConstNodeInfo(gt_, NAME_A);
  ASSERT_TRUE(params_a != nullptr);
  EXPECT_EQ(NAME_A, params_a->name());
  ASSERT_EQ(4, params_a->shape_size());
  EXPECT_EQ(1, params_a->shape(0));
  EXPECT_EQ(1, params_a->shape(1));
  EXPECT_EQ(1, params_a->shape(2));
  EXPECT_EQ(1, params_a->shape(3));
  EXPECT_EQ(4, params_a->data().length());

  const GraphTransferConstNodeInfo* params_b = FindConstNodeInfo(gt_, NAME_B);
  ASSERT_TRUE(params_b != nullptr);
  ASSERT_EQ(4, params_b->shape_size());
  EXPECT_EQ(1, params_b->shape(0));
  EXPECT_EQ(1, params_b->shape(1));
  EXPECT_EQ(1, params_b->shape(2));
  EXPECT_EQ(1, params_b->shape(3));
  EXPECT_EQ(4, params_b->data().length());
}

TEST_F(GraphTransfererTest, LoadAddGraphWithOutputTensorMap) {
  GraphDef def = CreateAddGraphDef();
  std::pair<string, Tensor> input_node_info_a;
  input_node_info_a.first = NAME_A;
  input_node_info_a.second = Tensor(DT_FLOAT, {});
  input_node_info_a.second.scalar<float>()() = 1.0f;
  const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a};
  RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info;
  Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
      def, inputs, {}, &output_tensor_info);
  ASSERT_TRUE(status.ok()) << status;
  const std::vector<string> output_node_names = {NAME_A_PLUS_B};
  status = gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
                                  inputs, output_node_names, false);
  TF_ASSERT_OK(status);
}

TEST_F(GraphTransfererTest, LoadConvGraph) {
  GraphDef def = CreateConvGraphDef();
  std::vector<std::pair<string, Tensor>> input_node_info_list;
  input_node_info_list.emplace_back(
      std::pair<string, Tensor>{"input", Tensor{DT_FLOAT, {1, 1, 1, 1}}});
  const std::vector<string> output_node_names = {"softmax"};
  ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
                                     input_node_info_list, output_node_names,
                                     false)
                  .ok());
  SanityCheckNodes(gt_);
  const int const_node_count =
      gt_.GetGraphTransferInfo().const_node_info_size();
  ASSERT_EQ(2, const_node_count);
  const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
  ASSERT_EQ(4, op_node_count);
  const GraphTransferNodeInfo* params_conv = FindNodeInfo(gt_, "conv");
  ASSERT_TRUE(params_conv != nullptr);
  const int id = params_conv->node_id();
  EXPECT_GE(id, 0);
  EXPECT_EQ("Conv2D", params_conv->type_name());
  EXPECT_EQ(3, params_conv->input_count());
  EXPECT_EQ(1, params_conv->output_count());
  EXPECT_EQ(Padding::SAME, params_conv->padding_id());
}

TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
  GraphDef def = CreatePoolGraphDef();
  std::vector<std::pair<string, Tensor>> input_node_info_list;
  input_node_info_list.emplace_back(
      std::pair<string, Tensor>{"input", Tensor{DT_FLOAT, {1, 1, 1, 1}}});
  const std::vector<string> output_node_names = {"softmax"};
  ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
                                     input_node_info_list, output_node_names,
                                     false)
                  .ok());
  SanityCheckNodes(gt_);
  const int const_node_count =
      gt_.GetGraphTransferInfo().const_node_info_size();
  ASSERT_EQ(2, const_node_count);
  const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
  ASSERT_EQ(4, op_node_count);
  const GraphTransferNodeInfo* params_max_pool = FindNodeInfo(gt_, "maxpool");
  ASSERT_TRUE(params_max_pool != nullptr);
  const int id = params_max_pool->node_id();
  EXPECT_GE(id, 0);
  EXPECT_EQ("MaxPool", params_max_pool->type_name());
  EXPECT_EQ(3, params_max_pool->input_count());
  EXPECT_EQ(1, params_max_pool->output_count());
  EXPECT_EQ(Padding::SAME, params_max_pool->padding_id());
}

TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
  const IRemoteFusedGraphOpsDefinitions& ops_definitions =
      HexagonOpsDefinitions::getInstance();
  const int total_ops_count = ops_definitions.GetTotalOpsCount();
  EXPECT_GT(total_ops_count, 0);
}

TEST(GraphTransferer, LoadGraphFromProtoFile) {
  const IRemoteFusedGraphOpsDefinitions* ops_definitions =
      &TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
  string filename =
      io::JoinPath(testing::TensorFlowSrcRoot(),
                   "core/example/testdata/parse_example_graph_def.pbtxt");
  std::vector<std::pair<string, Tensor>> input_node_info_list = {};
  std::vector<string> output_node_names = {};
  bool is_text_proto = true;

  // Keep the following comments for debugging purposes for now
  // filename = "v3_stripped_quantized_graph_opt.pb";
  // input_node_info_list.emplace_back(
  //     std::pair<string, Tensor>{"Mul", Tensor{DT_FLOAT, {1,299,299,3}}});
  // output_node_names.emplace_back("softmax");
  // is_text_proto = false;
  // ops_definitions = &HexagonOpsDefinitions::getInstance();

  GraphTransferer gt;
  gt.EnableStrictCheckMode(false);
  Status status = gt.LoadGraphFromProtoFile(
      *ops_definitions, filename, input_node_info_list, output_node_names,
      is_text_proto, false, true);
}

TEST_F(GraphTransfererTest, BuildRemoteFusedGraphDefAddGraph) {
  GraphDef def = CreateAddGraphDef();
  std::pair<string, Tensor> input_node_info_a;
  input_node_info_a.first = NAME_A;
  input_node_info_a.second = Tensor(DT_FLOAT, {});
  input_node_info_a.second.scalar<float>()() = 1.0f;
  std::pair<string, Tensor> input_node_info_b;
  input_node_info_b.first = NAME_B;
  input_node_info_b.second = Tensor(DT_FLOAT, {});
  input_node_info_b.second.scalar<float>()() = 10.0f;
  const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a,
                                                      input_node_info_b};
  std::vector<string> outputs = {NAME_A_PLUS_B};

  GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef(
      TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, "remote_fused_graph_execute_node",
      inputs, outputs, &def);

  EXPECT_EQ(3, fused_graph_def.node_size());
}

namespace {
// Just compares the max_byte_size attributes present.
void CompareGraphTransferInfo(const GraphTransferInfo& a,
                              const GraphTransferInfo& b) {
  EXPECT_EQ(a.node_output_info_size(), b.node_output_info_size());
  for (int i = 0; i < a.node_output_info_size(); ++i) {
    EXPECT_EQ(a.node_output_info(i).node_id(), b.node_output_info(i).node_id());
    EXPECT_EQ(a.node_output_info(i).max_byte_size_size(),
              b.node_output_info(i).max_byte_size_size());
    for (int j = 0; j < a.node_output_info(i).max_byte_size_size(); ++j) {
      EXPECT_EQ(a.node_output_info(i).max_byte_size(j),
                b.node_output_info(i).max_byte_size(j));
    }
  }
}
}  // anonymous namespace

TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) {
  const IRemoteFusedGraphOpsDefinitions* ops_definitions =
      &TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
  string filename =
      io::JoinPath(testing::TensorFlowSrcRoot(),
                   "core/example/testdata/parse_example_graph_def.pbtxt");
  std::vector<std::pair<string, Tensor>> input_node_info_list = {};
  std::vector<string> output_node_names = {};
  bool is_text_proto = true;

  // To run with a more complex graph, uncomment the following lines
  // filename = "v3_stripped_quantized_graph_opt.pb";
  // input_node_info_list.emplace_back(
  //     std::pair<string, Tensor>{"Mul", Tensor{DT_FLOAT, {1,299,299,3}}});
  // output_node_names.emplace_back("softmax");
  // is_text_proto = false;
  // ops_definitions = &HexagonOpsDefinitions::getInstance();

  // First compute using shape inference.
  GraphTransferer si_gt;
  si_gt.EnableStrictCheckMode(false);
  bool shape_inference_for_unknown_shape = true;
  bool dry_run_for_unknown_shape = false;
  Status status1 = si_gt.LoadGraphFromProtoFile(
      *ops_definitions, filename, input_node_info_list, output_node_names,
      is_text_proto, shape_inference_for_unknown_shape,
      dry_run_for_unknown_shape);
  const GraphTransferInfo& si_graph_transfer_info =
      si_gt.GetGraphTransferInfo();

  // Now compute using dry run.
  GraphTransferer dr_gt;
  dr_gt.EnableStrictCheckMode(false);
  shape_inference_for_unknown_shape = false;
  dry_run_for_unknown_shape = true;
  Status status2 = dr_gt.LoadGraphFromProtoFile(
      *ops_definitions, filename, input_node_info_list, output_node_names,
      is_text_proto, shape_inference_for_unknown_shape,
      dry_run_for_unknown_shape);
  const GraphTransferInfo& dr_graph_transfer_info =
      dr_gt.GetGraphTransferInfo();

  // Now compare both of them.
  CompareGraphTransferInfo(si_graph_transfer_info, dr_graph_transfer_info);
}

}  // namespace tensorflow
@@ -1,437 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"

#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"

namespace tensorflow {

constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX =
    "hexagon_remote_fused_graph";
/* static */ constexpr const char* const
    HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME;

constexpr int ALIGNMENT_BYTES = 16;
constexpr int MAX_IN_OUT_COUNT = 128;

const bool DBG_DUMP_VERIFICATION_STRING = false;
const int DBG_LEVEL = 0;  // -2: verbose, -1: debug, 0: info
const bool DBG_USE_DUMMY_INPUT = false;
const bool DBG_USE_SAMPLE_INPUT = false;
const int64 FLAG_ENABLE_PANDA_BINARY_INPUT = 0x01;
const bool DBG_DUMP_INPUT_TENSOR_AS_FLOAT_DATA = false;

static string AddPort(const string& node_name) {
  if (node_name.find(':') != string::npos) {
    return node_name;
  } else {
    return strings::StrCat(node_name, ":", 0);
  }
}

static uint8* FindAlignedPointer(uint8* ptr) {
  const uintptr_t data_ptr_int = reinterpret_cast<uintptr_t>(ptr);
  const int shift_count =
      (ALIGNMENT_BYTES - data_ptr_int % ALIGNMENT_BYTES) % ALIGNMENT_BYTES;
  uint8* data_ptr = ptr + shift_count;
  return data_ptr;
}

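FindAlignedPointer rounds a pointer up to the next 16-byte boundary; the outer % ALIGNMENT_BYTES makes the shift zero when the pointer is already aligned. A small self-contained check of the same formula, using plain integers only to illustrate the arithmetic:

#include <cassert>
#include <cstdint>

int main() {
  const int kAlign = 16;
  // Mirrors FindAlignedPointer: the shift is 0 when already aligned,
  // otherwise the distance to the next multiple of kAlign.
  for (uintptr_t p = 0; p < 64; ++p) {
    const uintptr_t shift = (kAlign - p % kAlign) % kAlign;
    assert((p + shift) % kAlign == 0);  // rounded up to a boundary
    assert(shift < kAlign);             // never overshoots a full block
  }
  return 0;
}
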
/* static */ GraphTransferNodeInfo* HexagonControlWrapper::FindNodeInfo(
    const string& name, GraphTransferInfo* graph_transfer_info) {
  for (GraphTransferNodeInfo& node_info :
       *graph_transfer_info->mutable_node_info()) {
    if (node_info.name() == name) {
      return &node_info;
    }
  }
  return nullptr;
}

int HexagonControlWrapper::GetVersion() {
  return soc_interface_GetSocControllerVersion();
}

bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
  soc_interface_SetLogLevel(DBG_LEVEL);
  if (DBG_USE_SAMPLE_INPUT) {
    soc_interface_SetDebugFlag(FLAG_ENABLE_PANDA_BINARY_INPUT);
  }
  if (info.serialized_executor_parameters().empty()) {
    std::vector<std::pair<string, Tensor>> inputs;
    std::vector<string> outputs;
    RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
        info, &inputs, &outputs);
    Status status = graph_transferer_.LoadGraphFromProto(
        HexagonOpsDefinitions::getInstance(), info.remote_graph(), inputs,
        outputs,
        false  // shape_inference_for_unknown_shape
    );
    TF_CHECK_OK(status) << status;
  } else {
    // If graph transfer info is attached, just import it.
    graph_transferer_.SetSerializedGraphTransferInfo(
        info.serialized_executor_parameters());
  }
  execute_info_ = &info;
  bool success = soc_interface_Init();
  if (!success) {
    LOG(ERROR) << "Hexagon initialization failed. See log output.";
    return false;
  }
  std::vector<int> input_sizes;
  std::vector<int> output_sizes;
  CHECK_NOTNULL(execute_info_);
  for (int i = 0; i < execute_info_->graph_input_node_name_size(); ++i) {
    const string& input = execute_info_->graph_input_node_name(i);
    LOG(INFO) << "Add input: " << input << ", " << i;
    CHECK(input_port_map_.emplace(AddPort(input), i).second);
    const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type =
        execute_info_->default_graph_input_tensor_shape(i);
    int64 buf_size = DataTypeSize(shape_type.dtype());
    for (const TensorShapeProto::Dim& dim : shape_type.shape().dim()) {
      buf_size *= dim.size();
    }
    input_sizes.emplace_back(static_cast<int>(buf_size));
  }
  for (int i = 0; i < execute_info_->graph_output_node_name_size(); ++i) {
    const string& output = execute_info_->graph_output_node_name(i);
    CHECK(output_port_map_.emplace(AddPort(output), i).second);
    const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type =
        execute_info_->default_graph_output_tensor_shape(i);

    int64 buf_size = DataTypeSize(shape_type.dtype());
    for (const TensorShapeProto::Dim& dim : shape_type.shape().dim()) {
      buf_size *= dim.size();
    }
    output_sizes.emplace_back(static_cast<int>(buf_size));
  }

  LOG(INFO) << "Allocate inout buffer";
  success &= soc_interface_AllocateInOutNodeBuffers(
      input_sizes.size(), input_sizes.data(), output_sizes.size(),
      output_sizes.data());
  return success;
}

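For scale, each buffer size computed in Init() is just the element size times the product of the dimensions: a DT_FLOAT input of shape {1, 299, 299, 3}, like the commented-out example in the tests earlier, comes to 4 * 1 * 299 * 299 * 3 = 1072812 bytes.
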
bool HexagonControlWrapper::Finalize() { return soc_interface_Finalize(); }
|
||||
bool HexagonControlWrapper::SetupGraph() {
|
||||
// Copy graph transfer info to modify to adapt hexnn library
|
||||
GraphTransferInfo& graph_transfer_info =
|
||||
graph_transferer_.GetMutableGraphTransferInfo();
|
||||
|
||||
// Overwrite op type of input nodes for hexagon
|
||||
for (const GraphTransferGraphInputNodeInfo& graph_input :
|
||||
graph_transfer_info.graph_input_node_info()) {
|
||||
GraphTransferNodeInfo* node_info =
|
||||
FindNodeInfo(graph_input.name(), &graph_transfer_info);
|
||||
CHECK_NE(node_info, nullptr);
|
||||
}
|
||||
|
||||
// Generate a new output node which is connected to graph output node
|
||||
// TODO(satok): Support multiple output nodes
|
||||
CHECK_EQ(graph_transfer_info.graph_output_node_info_size(), 1);
|
||||
for (const GraphTransferGraphOutputNodeInfo& graph_output :
|
||||
graph_transfer_info.graph_output_node_info()) {
|
||||
const int new_output_node_id = graph_transfer_info.node_info_size() +
|
||||
graph_transfer_info.const_node_info_size() +
|
||||
2 /* offset for ids */;
|
||||
// Register a new output node
|
||||
GraphTransferNodeInfo& new_output_node_info =
|
||||
*graph_transfer_info.add_node_info();
|
||||
new_output_node_info.set_name(OUTPUT_OP_NAME);
|
||||
new_output_node_info.set_node_id(new_output_node_id);
|
||||
new_output_node_info.set_type_name(OUTPUT_OP_NAME);
|
||||
new_output_node_info.set_soc_op_id(
|
||||
HexagonOpsDefinitions::getInstance().GetOpIdFor(OUTPUT_OP_NAME, {}));
|
||||
new_output_node_info.set_padding_id(0 /* PADDING_NA_ID */);
|
||||
new_output_node_info.set_input_count(1);
|
||||
new_output_node_info.set_output_count(0);
|
||||
|
||||
const TensorId tid = ParseTensorName(graph_output.name());
|
||||
const string node_name(tid.first);
|
||||
const int port = tid.second;
|
||||
// Register node input for the new output node
|
||||
const GraphTransferNodeInfo* node_info =
|
||||
FindNodeInfo(node_name, &graph_transfer_info);
|
||||
CHECK_NE(node_info, nullptr);
|
||||
GraphTransferNodeInputInfo& node_input_info =
|
||||
*graph_transfer_info.add_node_input_info();
|
||||
node_input_info.set_node_id(new_output_node_id);
|
||||
GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
|
||||
node_input.set_node_id(node_info->node_id());
|
||||
node_input.set_output_port(port);
|
||||
}
|
||||
|
||||
if (DBG_DUMP_VERIFICATION_STRING) {
|
||||
GraphTransferer gt;
|
||||
gt.SetSerializedGraphTransferInfo(graph_transfer_info.SerializeAsString());
|
||||
gt.DumpVerificationStringOfNodeTransferParams();
|
||||
}
|
||||
|
||||
int inputs_count = 0;
|
||||
int outputs_count = 0;
|
||||
for (const GraphTransferNodeInputInfo& input_params :
|
||||
graph_transfer_info.node_input_info()) {
|
||||
inputs_count += input_params.node_input_size();
|
||||
}
|
||||
|
||||
for (const GraphTransferNodeOutputInfo& output_params :
|
||||
graph_transfer_info.node_output_info()) {
|
||||
outputs_count += output_params.max_byte_size_size();
|
||||
}
|
||||
// Allocate memory for node inputs and node outputs
|
||||
soc_interface_AllocateNodeInputAndNodeOutputArray(inputs_count,
|
||||
outputs_count);
|
||||
|
||||
// Construct node input parameters
|
||||
std::unordered_map<int, std::tuple<void*, int>> inputs_map;
|
||||
for (const GraphTransferNodeInputInfo& input_params :
|
||||
graph_transfer_info.node_input_info()) {
|
||||
const int count = input_params.node_input_size();
|
||||
CHECK(count <= MAX_IN_OUT_COUNT);
|
||||
int node_ids[MAX_IN_OUT_COUNT];
|
||||
int ports[MAX_IN_OUT_COUNT];
|
||||
for (int i = 0; i < count; ++i) {
|
||||
const GraphTransferNodeInput& node_input = input_params.node_input(i);
|
||||
node_ids[i] = node_input.node_id() + NODE_ID_OFFSET;
|
||||
ports[i] = node_input.output_port();
|
||||
}
|
||||
void* inputs_ptr = soc_interface_SetOneNodeInputs(count, node_ids, ports);
|
||||
const int node_id = input_params.node_id();
|
||||
CHECK(inputs_map.count(node_id) == 0);
|
||||
inputs_map.emplace(node_id, std::make_tuple(inputs_ptr, count));
|
||||
}
|
||||
|
||||
// Construct node output parameters
|
||||
std::unordered_map<int, std::tuple<void*, int>> outputs_map;
|
||||
for (const GraphTransferNodeOutputInfo& output_params :
|
||||
graph_transfer_info.node_output_info()) {
|
||||
const int count = output_params.max_byte_size_size();
|
||||
CHECK(count <= MAX_IN_OUT_COUNT);
|
||||
int sizes[MAX_IN_OUT_COUNT];
|
||||
for (int i = 0; i < count; ++i) {
|
||||
const int size = output_params.max_byte_size(i);
|
||||
sizes[i] = size;
|
||||
}
|
||||
void* outputs_ptr = soc_interface_SetOneNodeOutputs(count, sizes);
|
||||
const int node_id = output_params.node_id();
|
||||
CHECK(outputs_map.count(node_id) == 0);
|
||||
outputs_map.emplace(node_id, std::make_tuple(outputs_ptr, count));
|
||||
}
|
||||
|
||||
// Instantiate graph
|
||||
soc_interface_InstantiateGraph();
|
||||
|
||||
// Initialize graph
|
||||
// 1. Setup const nodes
|
||||
for (const GraphTransferConstNodeInfo& params :
|
||||
graph_transfer_info.const_node_info()) {
|
||||
const int node_id = params.node_id();
|
||||
// TODO(satok): Stop assuming shape size is 4.
|
||||
CHECK(params.shape_size() == 4);
|
||||
const int64 shape_0 = params.shape(0);
|
||||
const int64 shape_1 = params.shape(1);
|
||||
const int64 shape_2 = params.shape(2);
|
||||
const int64 shape_3 = params.shape(3);
|
||||
const int data_size = params.data().length();
|
||||
CHECK(dummy_const_data_.count(node_id) == 0);
|
||||
auto data = dummy_const_data_.emplace(
|
||||
std::piecewise_construct, std::make_tuple(node_id), std::make_tuple());
|
||||
CHECK(data.second);
|
||||
data.first->second.resize(data_size + ALIGNMENT_BYTES - 1);
|
||||
uint8* data_ptr = FindAlignedPointer(data.first->second.data());
|
||||
std::memcpy(data_ptr, params.data().data(), data_size);
|
||||
soc_interface_AppendConstNode(params.name().c_str(),
|
||||
node_id + NODE_ID_OFFSET, shape_0, shape_1,
|
||||
shape_2, shape_3, data_ptr, data_size);
|
||||
}
|
||||
|
||||
// 2. Setup op nodes
|
||||
for (const GraphTransferNodeInfo& params : graph_transfer_info.node_info()) {
|
||||
const int node_id = params.node_id();
|
||||
const int op_id = params.soc_op_id();
|
||||
CHECK(inputs_map.count(node_id) == 1);
|
||||
CHECK(outputs_map.count(node_id) <= 1);
|
||||
// Only output node doesn't have output
|
||||
const bool has_output = outputs_map.count(node_id) == 1;
|
||||
const auto& input_ptr_and_count = inputs_map.at(node_id);
|
||||
const void* input_ptr = std::get<0>(input_ptr_and_count);
|
||||
const int input_count = std::get<1>(input_ptr_and_count);
|
||||
void* output_ptr = nullptr;
|
||||
int output_count = 0;
|
||||
if (has_output) {
|
||||
const auto& output_ptr_and_count = outputs_map.at(node_id);
|
||||
output_ptr = std::get<0>(output_ptr_and_count);
|
||||
output_count = std::get<1>(output_ptr_and_count);
|
||||
// CHECK(output_count > 0);
|
||||
}
|
||||
int padding_id = -1;
|
||||
if (params.padding_id() == 0) {
|
||||
padding_id = 0;
|
||||
} else if (params.padding_id() == Padding::SAME) {
|
||||
padding_id = 1;
|
||||
} else if (params.padding_id() == Padding::VALID) {
|
||||
padding_id = 2;
|
||||
} else {
|
||||
LOG(FATAL);
|
||||
}
|
||||
soc_interface_AppendNode(params.name().c_str(), node_id + NODE_ID_OFFSET,
|
||||
op_id, padding_id, input_ptr, input_count,
|
||||
output_ptr, output_count);
|
||||
}
|
||||
|
||||
LOG(INFO) << "Setup graph completed";
|
||||
|
||||
// 3. construct graph
|
||||
return soc_interface_ConstructGraph();
|
||||
|
||||
// Keep following comment to use dummy graph construction
|
||||
// return soc_interface_setupDummyGraph(3 /* inception version */);
|
||||
}
|
||||
|
||||
bool HexagonControlWrapper::ExecuteGraph() {
|
||||
return soc_interface_ExecuteGraph();
|
||||
}
|
||||
|
||||
bool HexagonControlWrapper::TeardownGraph() {
|
||||
soc_interface_ReleaseNodeInputAndNodeOutputArray();
|
||||
return soc_interface_TeardownGraph();
|
||||
}
|
||||
|
||||
bool HexagonControlWrapper::FillInputNode(
|
||||
const string& node_name,
|
||||
const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>& shape,
|
||||
const ConstByteArray bytes) {
|
||||
const string tensor_name = AddPort(node_name);
|
||||
CHECK(input_port_map_.count(tensor_name) > 0);
|
||||
const int port = input_port_map_.at(tensor_name);
|
||||
if (input_tensor_data_.count(port) <= 0) {
|
||||
input_tensor_data_.emplace(port, std::vector<uint8>{});
|
||||
}
|
||||
std::vector<uint8>& input_tensor_data = input_tensor_data_.at(port);
|
||||
|
||||
// hexagon only supports 32bit dimension
  const int x = static_cast<int>(shape[0]);
  const int y = static_cast<int>(shape[1]);
  const int z = static_cast<int>(shape[2]);
  const int d = static_cast<int>(shape[3]);

  const uint64 byte_size = x * y * z * d * DataTypeSize(std::get<2>(bytes));
  CHECK_EQ(byte_size, std::get<1>(bytes));
  input_tensor_data.resize(byte_size + ALIGNMENT_BYTES);
  uint8* data_ptr = FindAlignedPointer(input_tensor_data.data());

  if (DBG_USE_DUMMY_INPUT) {
    std::memset(data_ptr, 0, byte_size);
  } else {
    std::memcpy(data_ptr, std::get<0>(bytes), byte_size);
  }

  return soc_interface_FillInputNodeWithPort(port, x, y, z, d, data_ptr,
                                             byte_size);
}

bool HexagonControlWrapper::ReadOutputNode(
    const string& node_name, TensorAllocatorFunc tensor_allocator) {
  CHECK_NE(execute_info_, nullptr);
  TensorShape output_shape;
  // TODO(satok): Switch shape corresponding to input shape
  for (int i = 0; i < execute_info_->graph_output_node_name_size(); ++i) {
    if (execute_info_->graph_output_node_name(i) == node_name) {
      for (const TensorShapeProto::Dim& dim :
           execute_info_->default_graph_output_tensor_shape(i).shape().dim()) {
        output_shape.AddDim(dim.size());
      }
      break;
    }
  }
  std::vector<ByteArray> outputs;
  ReadOutputNode(node_name, &outputs);
  CHECK_EQ(1, outputs.size());
  ByteArray& output = outputs[0];
  Tensor* output_tensor = tensor_allocator(output_shape);
  CHECK(output_tensor->TotalBytes() >= std::get<1>(output))
      << output_tensor->TotalBytes() << ", " << std::get<1>(output);
  TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
      std::get<0>(output), std::get<1>(output), output_tensor));
  return true;
}

bool HexagonControlWrapper::ReadOutputNode(
    const string& node_name, std::vector<ByteArray>* const outputs) {
  CHECK(outputs != nullptr);
  ByteArray output;
  const string tensor_name = AddPort(node_name);
  CHECK(output_port_map_.count(tensor_name) > 0);
  const int port = output_port_map_.at(tensor_name);
  soc_interface_ReadOutputNodeWithPort(
      port, &std::get<0>(output),
      reinterpret_cast<uint64_t*>(&std::get<1>(output)));
  // TODO: Accept all results
  // std::get<2>(output) = DT_FLOAT;
  outputs->emplace_back(output);
  return true;
}

Status HexagonControlWrapper::FuseRemoteGraph(
    const GraphDef& original_graph_def, const std::vector<string>& inputs,
    const std::vector<string>& outputs, GraphDef* fused_graph_def) {
  const std::unordered_set<string> fused_node_names =
      RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
          original_graph_def, HexagonOpsDefinitions::getInstance());
  // TODO(satok): We may want to place shape and type inside this function
  // if they are not placed in the given graph.
  TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
      original_graph_def, inputs, outputs, REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX,
      fused_node_names, REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
      /*require_shape_type=*/true, fused_graph_def));
  return Status::OK();
}

bool HexagonControlWrapper::FillInputNode(const string& node_name,
                                          const Tensor& tensor) {
  StringPiece tensor_data = tensor.tensor_data();
  const ConstByteArray ba =
      ConstByteArray(reinterpret_cast<const uint8*>(tensor_data.data()),
                     tensor_data.size(), tensor.dtype());
  if (DBG_DUMP_INPUT_TENSOR_AS_FLOAT_DATA) {
    LOG(INFO) << "Input tensor data: element size = " << tensor.NumElements()
              << ", byte size = " << tensor.TotalBytes();
    std::stringstream line;
    for (int i = 0; i < tensor.NumElements(); ++i) {
      line << tensor.flat<float>().data()[i] << ", ";
      if ((i - 2) % 3 == 0 || i == tensor.NumElements() - 1) {
        LOG(INFO) << "(" << ((i - 2) / 3) << ") " << line.str();
        line.str("");
        line.clear();
      }
    }
  }
  const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE> shape =
      GraphTransferer::ToTensorShapeArray(tensor.shape());
  FillInputNode(node_name, shape, ba);
  return true;
}

bool HexagonControlWrapper::IsEnabled() const { return true; }

}  // namespace tensorflow
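
A note on the TensorAllocatorFunc overload of ReadOutputNode above: the wrapper derives the output shape from the execute info and then asks the caller to allocate the destination tensor. A minimal caller-side sketch, assuming TensorAllocatorFunc is the std::function<Tensor*(const TensorShape&)> alias declared in i_remote_fused_graph_executor.h and that `wrapper` is an already-initialized HexagonControlWrapper:

  // Sketch only: shows the allocator contract ReadOutputNode expects.
  Tensor output;  // Owned by the caller; must outlive the ReadOutputNode call.
  auto allocator = [&output](const TensorShape& shape) -> Tensor* {
    // `shape` is the graph's default output shape found by the wrapper;
    // allocate backing storage and hand back a pointer for it to fill.
    output = Tensor(DT_FLOAT, shape);
    return &output;
  };
  wrapper.ReadOutputNode("softmax", allocator);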
@ -1,91 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_

#include <array>
#include <tuple>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

/*
HexagonControlWrapper implements the interfaces in IRemoteFusedGraphExecutor.
This class calls APIs on hexagon via the hexagon control binary.
TODO(satok): Add more documentation about the hexagon control binary.
*/
class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
 public:
  using ByteArray =
      std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
  static constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
      "build_hexagon_remote_fused_graph_executor";

  HexagonControlWrapper() = default;
  int GetVersion() final;
  bool Init(const RemoteFusedGraphExecuteInfo& info) final;
  bool Finalize() final;
  bool SetupGraph() final;
  bool ExecuteGraph() final;
  bool TeardownGraph() final;
  bool FillInputNode(const string& node_name, const Tensor& tensor) final;
  bool ReadOutputNode(const string& node_name,
                      TensorAllocatorFunc tensor_allocator) final;
  Status FuseRemoteGraph(const GraphDef& original_graph_def,
                         const std::vector<string>& inputs,
                         const std::vector<string>& outputs,
                         GraphDef* fused_graph_def) final;
  bool IsEnabled() const final;
  bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);

 private:
  using ConstByteArray = std::tuple<const uint8* /* data */, uint64 /* size */,
                                    DataType /* type */>;

  bool FillInputNode(
      const string& node_name,
      const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>& shape,
      const ConstByteArray bytes);

  // CAVEAT: Need offset as HVX library reserves some ids
  static constexpr int NODE_ID_OFFSET = 0x10000;

  static GraphTransferNodeInfo* FindNodeInfo(
      const string& name, GraphTransferInfo* graph_transfer_info);

  const RemoteFusedGraphExecuteInfo* execute_info_{};
  GraphTransferer graph_transferer_{};
  // Dummy float array for input node.
  // TODO(satok): Use actual data passed by FillInputNode and remove
  // std::vector<float> dummy_input_float_{};
  std::unordered_map<int, std::vector<uint8>> input_tensor_data_{};
  // Dummy byte array for const node.
  // TODO(satok): Remove
  std::unordered_map<int, std::vector<uint8>> dummy_const_data_{};

  std::unordered_map<string, int> input_port_map_{};
  std::unordered_map<string, int> output_port_map_{};

  TF_DISALLOW_COPY_AND_ASSIGN(HexagonControlWrapper);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
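
For orientation, the executor interface above is driven in a fixed lifecycle; RunInferenceByHexagonControlWrapper in the test file later in this diff follows exactly this order. A minimal sketch, with error handling elided and `info`/`input_tensor` assumed to be prepared by the caller:

  HexagonControlWrapper wrapper;
  wrapper.Init(info);    // Bind the RemoteFusedGraphExecuteInfo to the wrapper.
  wrapper.SetupGraph();  // Instantiate the graph on the DSP.
  wrapper.FillInputNode("Mul", input_tensor);
  wrapper.ExecuteGraph();  // May be called repeatedly, e.g. for benchmarking.
  std::vector<HexagonControlWrapper::ByteArray> outputs;
  wrapper.ReadOutputNode("softmax", &outputs);
  wrapper.TeardownGraph();  // Release the graph on the DSP.
  wrapper.Finalize();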
@ -1,604 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/* Before calling this test program, download a model as follows.
   $ curl \
   https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb \
   -o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
   $ adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
     /data/local/tmp
   $ curl \
   https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt \
   -o /tmp/imagenet_comp_graph_label_strings.txt
   $ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
*/

// define EIGEN_USE_THREADS to include quantization_utils.h
#define EIGEN_USE_THREADS

#include <memory>

#include "absl/base/casts.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/profile_utils/clock_cycle_profiler.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"

namespace tensorflow {

using ByteArray = HexagonControlWrapper::ByteArray;

constexpr const char* const IMAGE_FILENAME = "/data/local/tmp/img_299x299.bmp";
constexpr const char* const MODEL_FILENAME =
    "/data/local/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb";
constexpr const char* const MODEL_WITH_QUANTIZED_INPUT_FILENAME =
    "/data/local/tmp/"
    "tensorflow_inception_v3_stripped_optimized_quantized_with_quantized_input."
    "pb";
constexpr const char* const FUSED_MODEL_FILENAME =
    "/data/local/tmp/"
    "tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb";
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME =
    "remote_fused_graph_execute_node";
constexpr bool USE_SHAPE_INFERENCE = false;

const bool DBG_DUMP_FLOAT_DATA = false;
const int WIDTH = 299;
const int HEIGHT = 299;
const int DEPTH = 3;
const int EXPECTED_FIRST_RESULT_ID = 59;
const int EXECUTION_REPEAT_COUNT = 10;

static void CheckHexagonControllerVersion() {
  HexagonControlWrapper hexagon_control_wrapper;
  const int version = hexagon_control_wrapper.GetVersion();
  ASSERT_GE(version, 1);
  LOG(INFO) << "Hexagon controller version is " << version;
}

static void DumpTop10Results(const int byte_size,
                             const float* const float_array) {
  const int element_count = byte_size / sizeof(float);
  const string label_filename =
      "/data/local/tmp/imagenet_comp_graph_label_strings.txt";
  string label_str;
  TF_CHECK_OK(ReadFileToString(Env::Default(), label_filename, &label_str));
  std::vector<string> labels = str_util::Split(label_str, '\n');
  GraphTransferUtils::DumpTopNFloatResults(
      float_array, labels.data(),
      std::min(element_count, static_cast<int>(labels.size())),
      10 /* show top_n results */);
}

static void DumpTop10Results(const std::vector<ByteArray>& outputs) {
  CHECK(outputs.size() == 1);
  const int byte_size = std::get<1>(outputs.at(0));
  const float* float_array =
      reinterpret_cast<float*>(std::get<0>(outputs.at(0)));
  DumpTop10Results(byte_size, float_array);
}

static void CheckFirstResult(const std::vector<ByteArray>& outputs,
                             const int expected_first_id) {
  EXPECT_GE(outputs.size(), 1);
  const int byte_size = std::get<1>(outputs.at(0));
  const int element_count = byte_size / sizeof(float);
  const float* float_array =
      reinterpret_cast<float*>(std::get<0>(outputs.at(0)));
  EXPECT_GE(element_count, 1);
  std::vector<string> labels(element_count);
  std::priority_queue<std::tuple<float, int, string>> queue =
      GraphTransferUtils::GetTopNFloatResults(float_array, labels.data(),
                                              element_count);
  const std::tuple<float, int, string>& entry = queue.top();
  EXPECT_EQ(expected_first_id, std::get<1>(entry));
}

static void LoadImage(std::vector<float>* img_floats_ptr) {
  CHECK(img_floats_ptr != nullptr);
  std::vector<float>& img_floats = *img_floats_ptr;
  // Read the data from the bitmap file into memory
  string bmp;
  TF_CHECK_OK(ReadFileToString(Env::Default(), IMAGE_FILENAME, &bmp));
  const int fsize = bmp.size();
  LOG(INFO) << "Read " << IMAGE_FILENAME << ", size = " << fsize << " bytes";
  const int64 pixel_count = WIDTH * HEIGHT * DEPTH;
  CHECK(fsize >= 22 /* pos of height */ + sizeof(int));
  CHECK(bmp.data() != nullptr);
  uint8* const img_bytes = absl::bit_cast<uint8*>(bmp.data());
  const int header_size = *(reinterpret_cast<int*>(img_bytes + 10));
  LOG(INFO) << "header size = " << header_size;
  const int size = *(reinterpret_cast<int*>(img_bytes + 14));
  LOG(INFO) << "image size = " << size;
  const int width = *(reinterpret_cast<int*>(img_bytes + 18));
  LOG(INFO) << "width = " << width;
  const int height = *(reinterpret_cast<int*>(img_bytes + 22));
  LOG(INFO) << "height = " << height;
  CHECK(fsize >= (WIDTH + 1) * WIDTH * 3 + header_size);

  uint8* const bmp_pixels = &img_bytes[header_size];

  img_floats.resize(pixel_count);
  int src_pixel_index = 0;
  CHECK(pixel_count % 3 == 0);
  for (int i = 0; i < pixel_count / 3; ++i) {
    const int src_pos = 3 * src_pixel_index;
    const int dst_pos = 3 * i;
    ++src_pixel_index;
    CHECK(src_pos + 2 + header_size < fsize);
    CHECK(dst_pos + 2 < pixel_count);
    // Convert (B, G, R) in bitmap to (R, G, B)
    img_floats[dst_pos] =
        (static_cast<float>(bmp_pixels[src_pos + 2]) - 128.0f) / 128.0f;
    img_floats[dst_pos + 1] =
        (static_cast<float>(bmp_pixels[src_pos + 1]) - 128.0f) / 128.0f;
    img_floats[dst_pos + 2] =
        (static_cast<float>(bmp_pixels[src_pos]) - 128.0f) / 128.0f;
    if (DBG_DUMP_FLOAT_DATA) {
      LOG(INFO) << i << " (" << img_floats[dst_pos] << ", "
                << img_floats[dst_pos + 1] << ", " << img_floats[dst_pos + 2]
                << ") (" << static_cast<int>(bmp_pixels[src_pos + 2]) << ", "
                << static_cast<int>(bmp_pixels[src_pos + 1]) << ", "
                << static_cast<int>(bmp_pixels[src_pos]) << ")";
    }
    if (src_pixel_index % (WIDTH + 1) == (WIDTH - 1)) {
      // skip bmp padding
      ++src_pixel_index;
    }
  }
}

static void QuantizeImage(const std::vector<float>& float_vec,
                          std::vector<quint8>* quint8_vec) {
  quint8_vec->resize(float_vec.size());
  for (int i = 0; i < float_vec.size(); ++i) {
    quint8_vec->at(i) = FloatToQuantized<quint8>(float_vec[i], -1.0f, 1.0f);
  }
}
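
QuantizeImage maps each float in [-1.0f, 1.0f] onto the full quint8 range. A hedged sketch of the arithmetic involved, assuming FloatToQuantized's usual linear min/max scheme (the authoritative clamping and rounding details live in quantization_utils.h; QuantizeToQuint8 here is a hypothetical stand-in, not the TensorFlow API):

  #include <algorithm>
  #include <cmath>
  #include <cstdint>

  // Linearly rescale [min_v, max_v] onto the 8-bit range [0, 255].
  inline uint8_t QuantizeToQuint8(float v, float min_v, float max_v) {
    const float clamped = std::max(min_v, std::min(max_v, v));
    return static_cast<uint8_t>(
        std::lround((clamped - min_v) / (max_v - min_v) * 255.0f));
  }
  // With (min_v, max_v) = (-1.0f, 1.0f): -1.0f -> 0, 0.0f -> 128, 1.0f -> 255.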

static Tensor BuildImageTensor(const std::vector<float>& img_floats) {
  LOG(INFO) << "Loading image finished.";
  Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
  CHECK_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
  CHECK_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
  LOG(INFO) << "Copy data to tensor.";
  std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
              img_tensor.TotalBytes());
  return img_tensor;
}

static Tensor BuildQuantizedImageTensor(
    const std::vector<quint8>& quantized_img) {
  LOG(INFO) << "Loading image finished.";
  Tensor img_tensor(DT_QUINT8, {1, WIDTH, HEIGHT, DEPTH});
  CHECK_EQ(WIDTH * HEIGHT * DEPTH, quantized_img.size());
  CHECK_EQ(img_tensor.TotalBytes(), quantized_img.size() * sizeof(quint8));
  LOG(INFO) << "Copy data to tensor.";
  std::memcpy(img_tensor.flat<quint8>().data(), quantized_img.data(),
              img_tensor.TotalBytes());
  return img_tensor;
}

/* static */ RemoteFusedGraphExecuteInfo
BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
    const GraphTransferInfo& graph_transfer_info) {
  RemoteFusedGraphExecuteInfo execute_info;
  execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");
  for (const GraphTransferGraphInputNodeInfo& input :
       graph_transfer_info.graph_input_node_info()) {
    execute_info.add_graph_input_node_name(input.name());
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
        *execute_info.add_default_graph_input_tensor_shape();
    tensor_shape_type.set_dtype(input.dtype());
    TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
    for (const int64 dim : input.shape()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

  for (const GraphTransferGraphOutputNodeInfo& output :
       graph_transfer_info.graph_output_node_info()) {
    execute_info.add_graph_output_node_name(output.name());
    RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
        *execute_info.add_default_graph_output_tensor_shape();
    tensor_shape_type.set_dtype(output.dtype());
    TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
    for (const int64 dim : output.shape()) {
      tensor_shape_proto.add_dim()->set_size(dim);
    }
  }

  execute_info.set_serialized_executor_parameters(
      graph_transfer_info.SerializeAsString());
  return execute_info;
}

static void RunInferenceByHexagonControlWrapper(const GraphTransferer& gt,
                                                const Tensor& img_tensor) {
  const RemoteFusedGraphExecuteInfo execute_info =
      BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
          gt.GetGraphTransferInfo());

  HexagonControlWrapper hexagon_control_wrapper;
  // 1. Initialize hexagon
  hexagon_control_wrapper.Init(execute_info);

  // 2. Setup graph in hexagon
  hexagon_control_wrapper.SetupGraph();

  // 3. Fill input node's output
  hexagon_control_wrapper.FillInputNode("Mul", img_tensor);

  // 4. Execute graph
  const int64 start_time_us = Env::Default()->NowMicros();
  for (int i = 0; i < EXECUTION_REPEAT_COUNT; ++i) {
    hexagon_control_wrapper.ExecuteGraph();
  }
  const int64 end_time_us = Env::Default()->NowMicros();

  // 5-1. Read output node's outputs
  std::vector<ByteArray> outputs;
  hexagon_control_wrapper.ReadOutputNode("softmax", &outputs);

  // 5-2. Dump results
  DumpTop10Results(outputs);
  CheckFirstResult(outputs, EXPECTED_FIRST_RESULT_ID);
  LOG(INFO) << "Average execution time = "
            << (end_time_us - start_time_us) / EXECUTION_REPEAT_COUNT << " us";

  // 6. Teardown graph in hexagon
  hexagon_control_wrapper.TeardownGraph();

  // 7. Finalize hexagon
  hexagon_control_wrapper.Finalize();
}

static void RunFusedGraph(const GraphDef& fused_graph_def) {
  // Setup input tensor
  std::vector<float> img_floats;
  LoadImage(&img_floats);

  LOG(INFO) << "Loading image finished.";
  const Tensor img_tensor = BuildImageTensor(img_floats);

  // Setup session
  std::vector<Tensor> output_tensors;
  SessionOptions session_options;
  session_options.env = Env::Default();
  std::unique_ptr<Session> session =
      std::unique_ptr<Session>(NewSession(session_options));
  TF_ASSERT_OK(session->Create(fused_graph_def));

  // Setup session arguments
  RunOptions run_options;
  run_options.set_trace_level(RunOptions::FULL_TRACE);
  RunMetadata run_metadata;

  std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
  input_tensors.emplace_back("Mul", img_tensor);
  std::vector<string> output_node_names;
  output_node_names.emplace_back(REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME);

  LOG(INFO) << "Run graph";
  // Run inference with all nodes as outputs
  TF_ASSERT_OK(session->Run(run_options, input_tensors, output_node_names, {},
                            &output_tensors, &run_metadata));
  ASSERT_EQ(1, output_tensors.size());
  const Tensor& output_tensor = output_tensors.at(0);
  LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes();
  LOG(INFO) << "Output shape = " << output_tensor.shape().DebugString();
  DumpTop10Results(
      output_tensor.TotalBytes(),
      reinterpret_cast<const float*>(output_tensor.flat<float>().data()));
}

static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
                                     const GraphTransferInfo& gfi1) {
  LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", "
            << gfi1.const_node_info_size();

  // 1. check node_info
  ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
  for (int i = 0; i < gfi0.node_info_size(); ++i) {
    const GraphTransferNodeInfo& ni0 = gfi0.node_info(i);
    const GraphTransferNodeInfo& ni1 = gfi1.node_info(i);
    EXPECT_EQ(ni0.DebugString(), ni1.DebugString());
    EXPECT_EQ(ni0.ByteSizeLong(), ni1.ByteSizeLong());
  }

  // 2. check const_node_info
  ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size());
  for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
    const GraphTransferConstNodeInfo& cni0 = gfi0.const_node_info(i);
    const GraphTransferConstNodeInfo& cni1 = gfi1.const_node_info(i);
    ASSERT_EQ(cni0.shape_size(), cni1.shape_size());
    for (int j = 0; j < cni0.shape_size(); ++j) {
      EXPECT_EQ(cni0.shape(j), cni1.shape(j));
    }
    EXPECT_EQ(cni0.ByteSizeLong(), cni1.ByteSizeLong());
    EXPECT_EQ(cni0.DebugString(), cni1.DebugString());
  }

  // 3. check node_input_info
  ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size());
  for (int i = 0; i < gfi0.node_input_info_size(); ++i) {
    const GraphTransferNodeInputInfo& nii0 = gfi0.node_input_info(i);
    const GraphTransferNodeInputInfo& nii1 = gfi1.node_input_info(i);
    EXPECT_EQ(nii0.ByteSizeLong(), nii1.ByteSizeLong());
    EXPECT_EQ(nii0.DebugString(), nii1.DebugString());
  }

  // 4. check node_output_info
  ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
  for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
    const GraphTransferNodeOutputInfo& noi0 = gfi0.node_output_info(i);
    const GraphTransferNodeOutputInfo& noi1 = gfi1.node_output_info(i);
    ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size());
    for (int j = 0; j < noi0.max_byte_size_size(); ++j) {
      EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j));
    }
    EXPECT_EQ(noi0.ByteSizeLong(), noi1.ByteSizeLong());
    EXPECT_EQ(noi0.DebugString(), noi1.DebugString());
  }

  // 5. check graph_input_node_info
  ASSERT_EQ(gfi0.graph_input_node_info_size(),
            gfi1.graph_input_node_info_size());
  for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) {
    const GraphTransferGraphInputNodeInfo& gini0 =
        gfi0.graph_input_node_info(i);
    const GraphTransferGraphInputNodeInfo& gini1 =
        gfi1.graph_input_node_info(i);
    EXPECT_EQ(gini0.ByteSizeLong(), gini1.ByteSizeLong());
    EXPECT_EQ(gini0.DebugString(), gini1.DebugString());
  }

  // 6. check graph_output_node_info
  ASSERT_EQ(gfi0.graph_output_node_info_size(),
            gfi1.graph_output_node_info_size());
  for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) {
    const GraphTransferGraphOutputNodeInfo& goni0 =
        gfi0.graph_output_node_info(i);
    const GraphTransferGraphOutputNodeInfo& goni1 =
        gfi1.graph_output_node_info(i);
    EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
    EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
  }
}

// CAVEAT: This test only runs when you specify the hexagon library using
// the makefile.
// CAVEAT: This test is disabled by default because hexagon can keep only
// two inception graphs in memory, which are allocated by the other two tests.
// Memory for these graphs is not released until the process is killed right
// now.
// TODO(satok): Figure out how to release memory on hexagon without process
// termination.
#ifdef USE_HEXAGON_LIBS
TEST(GraphTransferer,
     DISABLED_RunInceptionV3OnHexagonExampleWithHexagonWrapper) {
  LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
  CheckHexagonControllerVersion();

  const IRemoteFusedGraphOpsDefinitions* ops_definitions =
      &HexagonOpsDefinitions::getInstance();
  std::vector<std::pair<string, Tensor>> inputs;
  inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
  std::vector<string> output_node_names = {"softmax"};

  GraphTransferer gt;
  gt.EnableStrictCheckMode(false);
  profile_utils::CpuUtils::EnableClockCycleProfiling();
  ClockCycleProfiler prof;
  prof.Start();
  Status status = gt.LoadGraphFromProtoFile(
      *ops_definitions, MODEL_FILENAME, inputs, output_node_names,
      false,  // is_text_proto
      false,  // shape_inference_for_unknown_shape
      true    // dry_run_for_unknown_shape
  );
  ASSERT_TRUE(status.ok()) << status;
  prof.Stop();
  prof.DumpStatistics("LoadGraphFromProtoFile");

  std::vector<float> img_floats;
  LoadImage(&img_floats);
  const Tensor img_tensor = BuildImageTensor(img_floats);
  RunInferenceByHexagonControlWrapper(gt, img_tensor);
}

TEST(GraphTransferer,
     DISABLED_RunInceptionV3OnHexagonExampleWithHexagonWrapperQuantizedInput) {
  LOG(INFO) << "Run inception v3 on hexagon with hexagon controller "
            << "with quantized input";
  CheckHexagonControllerVersion();

  const IRemoteFusedGraphOpsDefinitions* ops_definitions =
      &HexagonOpsDefinitions::getInstance();
  std::vector<std::pair<string, Tensor>> inputs;
  inputs.emplace_back("Mul", Tensor(DT_QUINT8, {1, WIDTH, HEIGHT, DEPTH}));
  std::vector<string> output_node_names = {"softmax"};

  GraphTransferer gt;
  gt.EnableStrictCheckMode(false);
  profile_utils::CpuUtils::EnableClockCycleProfiling();
  ClockCycleProfiler prof;
  prof.Start();
  Status status = gt.LoadGraphFromProtoFile(
      *ops_definitions, MODEL_WITH_QUANTIZED_INPUT_FILENAME, inputs,
      output_node_names,
      /*is_text_proto=*/false,
      /*shape_inference_for_unknown_shape=*/false,
      /*dry_run_for_unknown_shape=*/true);
  ASSERT_TRUE(status.ok()) << status;
  prof.Stop();
  prof.DumpStatistics("LoadGraphFromProtoFile");

  std::vector<float> img_floats;
  LoadImage(&img_floats);
  std::vector<quint8> quantized_img;
  QuantizeImage(img_floats, &quantized_img);
  const Tensor img_tensor = BuildQuantizedImageTensor(quantized_img);
  RunInferenceByHexagonControlWrapper(gt, img_tensor);
}

TEST(GraphTransferer,
     DISABLED_RunInceptionV3OnHexagonExampleWithHexagonWrapperShapeInference) {
  LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
  CheckHexagonControllerVersion();

  const IRemoteFusedGraphOpsDefinitions* ops_definitions =
      &HexagonOpsDefinitions::getInstance();
  std::vector<std::pair<string, Tensor>> inputs;
  inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
  std::vector<string> output_node_names = {"softmax"};

  GraphTransferer gt;
  gt.EnableStrictCheckMode(false);
  profile_utils::CpuUtils::EnableClockCycleProfiling();
  ClockCycleProfiler prof;
  prof.Start();
  Status status = gt.LoadGraphFromProtoFile(
      *ops_definitions, MODEL_FILENAME, inputs, output_node_names,
      false,  // is_text_proto
      true,   // shape_inference_for_unknown_shape
      false   // dry_run_for_unknown_shape
  );
  ASSERT_TRUE(status.ok()) << status;
  prof.Stop();
  prof.DumpStatistics("LoadGraphFromProtoFile");

  std::vector<float> img_floats;
  LoadImage(&img_floats);
  const Tensor img_tensor = BuildImageTensor(img_floats);
  RunInferenceByHexagonControlWrapper(gt, img_tensor);
}

TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
  LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
  CheckHexagonControllerVersion();

  const IRemoteFusedGraphOpsDefinitions* ops_definitions =
      &HexagonOpsDefinitions::getInstance();
  std::vector<std::pair<string, Tensor>> inputs;
  inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
  std::vector<string> outputs = {"softmax"};

  std::vector<float> img_floats;
  LoadImage(&img_floats);

  LOG(INFO) << "Loading image finished.";

  GraphDef graph_def;
  Status status = ReadBinaryProto(Env::Default(), MODEL_FILENAME, &graph_def);

  ASSERT_TRUE(status.ok());

  LOG(INFO) << "Build fused graph";
  GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef(
      HexagonOpsDefinitions::getInstance(),
      REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME, inputs, outputs, &graph_def);

  RunFusedGraph(fused_graph_def);
}

TEST(GraphTransferer, DISABLED_RunInceptionV3OnHexagonExampleWithFusedGraph) {
  LOG(INFO) << "Run inception v3 with fused graph";
  CheckHexagonControllerVersion();

  GraphDef fused_graph_def;
  Status status =
      ReadBinaryProto(Env::Default(), FUSED_MODEL_FILENAME, &fused_graph_def);
  RunFusedGraph(fused_graph_def);
}

TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
  CheckHexagonControllerVersion();
  profile_utils::CpuUtils::EnableClockCycleProfiling();

  const IRemoteFusedGraphOpsDefinitions* ops_definitions =
      &HexagonOpsDefinitions::getInstance();
  std::vector<std::pair<string, Tensor>> inputs;
  inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
  std::vector<string> output_node_names = {"softmax"};

  GraphTransferer gt0;
  gt0.EnableStrictCheckMode(false);
  ClockCycleProfiler prof0;
  prof0.Start();
  Status status = gt0.LoadGraphFromProtoFile(
      *ops_definitions, MODEL_FILENAME, inputs, output_node_names,
      false,  // is_text_proto
      false,  // shape_inference_for_unknown_shape
      true    // dry_run_for_unknown_shape
  );
  const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo();

  ASSERT_TRUE(status.ok());
  prof0.Stop();
  prof0.DumpStatistics("Estimate shape by dryrun");

  LOG(INFO) << "(0) node count: " << gfi0.node_info_size() << ", "
            << gfi0.const_node_info_size();

  GraphTransferer gt1;
  gt1.EnableStrictCheckMode(true);
  ClockCycleProfiler prof1;
  prof1.Start();
  status = gt1.LoadGraphFromProtoFile(
      *ops_definitions, MODEL_FILENAME, inputs, output_node_names,
      false,  // is_text_proto
      true,   // shape_inference_for_unknown_shape
      false   // dry_run_for_unknown_shape
  );
  const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo();

  ASSERT_TRUE(status.ok());
  prof1.Stop();
  prof1.DumpStatistics("Estimate shape by shape inference");

  CompareGraphTransferInfo(gfi0, gfi1);

  const RemoteFusedGraphExecuteInfo ei0 =
      BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi0);
  const RemoteFusedGraphExecuteInfo ei1 =
      BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi1);

  GraphTransferInfo rgfi0;
  rgfi0.ParseFromString(ei0.serialized_executor_parameters());
  GraphTransferInfo rgfi1;
  rgfi1.ParseFromString(ei1.serialized_executor_parameters());

  CompareGraphTransferInfo(rgfi0, rgfi1);
  CompareGraphTransferInfo(gfi0, rgfi0);
  CompareGraphTransferInfo(gfi1, rgfi1);
}
#endif  // USE_HEXAGON_LIBS

}  // namespace tensorflow
@ -1,408 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"

#include "tensorflow/core/framework/types.h"

// CAVEAT: Uncomment the following macro if you want to use experimental
// hexagon ops.
// #define ENABLE_EXPERIMENTAL_HEXNN_OPS

namespace tensorflow {

// HVX internal supported op names
// TODO(satok): Remove this map once hexnn lib supports an API to retrieve op id
// from op name and data type
enum class HexagonOpsDefinitions::SupportedOpType {
  INPUT,
  OUTPUT,
  NOP,
  OP_CONST, /* OP_ is required to avoid compilation error on windows */
  CHECK,
  CLOSE_FLOAT32,
  CLOSE_QINT8,
  CLOSE_Q_QINT8,
  CLOSE_INT32,
  CLOSE_QINT32,
  PPRINT_8,
  PPRINT_32,
  PPRINT_FLOAT,
  PREFREE,
  FLATTEN,

#ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS
  // With Reference
  QUANTIZEDCONV2D_8X8TO32,
  QUANTIZEDCONV2D_8X8TO32_REF,
  QUANTIZEDMATMUL_8X8TO32,
  QUANTIZEDMATMUL_8X8TO32_REF,
  QUANTIZEDOWNANDSHRINKRANGE_32TO8,
  QUANTIZEDOWNANDSHRINKRANGE_32TO8_REF,
  QUANTIZEDRELU_8,
  QUANTIZEDRELU_8_REF,
  QUANTIZEDRELUX_8,
  QUANTIZEDRELUX_8_REF,
  QUANTIZEDMAXPOOL_8,
  QUANTIZEDMAXPOOL_8_REF,
  QUANTIZEDAVGPOOL_8,
  QUANTIZEDAVGPOOL_8_REF,
  QUANTIZEDCONCAT_8,
  QUANTIZEDCONCAT_8_REF,
  QUANTIZEDBIASADD_8P8TO32,
  QUANTIZEDBIASADD_8P8TO32_REF,
  MIN_F,
  MIN_F_REF,
  MAX_F,
  MAX_F_REF,
  QUANTIZE,
  QUANTIZE_REF,
  DEQUANTIZE,
  DEQUANTIZE_REF,
  SUPERNODE_8X8P8TO8,
  SUPERNODE_8X8P8TO8_REF,

  QUANTIZEDFLATTEN,
  SOFTMAX_F,
  CONV2D_F,
  MATMUL_F,
  RELU_F,
  RELUX_F,
  AVGPOOL_F,
  MAXPOOL_F,
  CONCAT_F,
  BIASADD_F,
  LRN_F,

  VARIABLE,
  ASSIGN,
  RESHAPE,
  QUANTIZED_RESHAPE,
  TANH_F,
  SIGMOID_F,
  SLICE_8,
  SLICE_F,
  QUANTIZED_SLICE_8,
  ADD_F,
  MUL_F,
  MINIMUM_F,
  MAXIMUM_F,

  REQUANTIZE_32_TO_8,
  REQUANTIZE_32_TO_8_REF,
  REQUANTIZATION_RANGE_32,
  REQUANTIZATION_RANGE_32_REF,

  NEG_F,
  SUB_F,
  ADD_N_F,
  RANGE_INT32,
  RANK_INT32,
  TRANSPOSE_INT32,
  TRANSPOSE_F,
  INSTANCE_NORM_F,
  QUANTIZED_INSTANCENORM_8,
  QUANTIZED_INSTANCENORM_8_REF,
  SUB_INT32,
  ADD_INT32,
  SPLIT_F,
  DEQUANTIZE_QINT32_F,
  PRELU_F,
  QUANTIZED_PRELU_8,
  SUM_F,
  PROD_F,
  MUL_INT32,
  LOGICAL_AND_INT32,
  LOGICALOR_INT32,
  LOGICAL_XOR_INT32,
  SPAPE_INT32,
  PACK_INT32,
  MIRROR_PAD_F,
  RESIZE_NEAREST_NEIGHBOR_F,
  STRIDED_SLICE_INT32,
  STRIDED_SLICE_F,
  EXPAND_DIMS_INT32,
  EXPAND_DIMS_F,

  LOG_SOFTMAX_F,
  SPLIT_INT32,
  QUANTIZED_SPLIT_8,

  DECONV_F,
  QUANTIZED_DECONV_8X8TO32,
  QUANTIZED_DECONV_8X8TO32_REF,

  QUANTIZED_MUL_8x8to32,
  QUANTIZED_MUL_8x8to32_REF,
  QUANTIZED_ADD_8p8to32,
  QUANTIZED_ADD_8p8to32_REF,
  QUANTIZED_SIGMOID_8,
  QUANTIZED_SIGMOID_8_REF,
  QUANTIZED_TANH_8,
  QUANTIZED_TANH_8_REF,
  QUANTIZED_SOFTMAX_8,
  QUANTIZED_SOFTMAX_8_REF,
  QUANTIZED_LRN_8,
  QUANTIZED_LRN_8_REF,
  QUANTIZED_PAD2D_FRAME_8P,
  QUANTIZED_PAD2D_FRAME_8P_REF,
  QUANTIZED_SUB_8P8TO32,
  QUANTIZED_SUB_8P8TO32_REF,
  QUANTIZED_MAXIMUM_8,
  QUANTIZED_MAXIMUM_8_REF,
  QUANTIZED_MINIMUM_8,
  QUANTIZED_MINIMUM_8_REF,

  PAD_F,
  SPACE_TO_BATCH_ND_F,
  BATCH_TO_SPACE_ND_F,
  RESIZE_BILINEAR_F,
  CONCAT_V2_F,

#else
  // With Reference
  QUANTIZEDCONV2D_8X8TO32,
  QUANTIZEDCONV2D_8X8TO32_REF,
  QUANTIZEDMATMUL_8X8TO32,
  QUANTIZEDMATMUL_8X8TO32_REF,
  QUANTIZEDOWNANDSHRINKRANGE_32TO8,
  QUANTIZEDOWNANDSHRINKRANGE_32TO8_REF,
  QUANTIZEDRELU_8,
  QUANTIZEDRELU_8_REF,
  QUANTIZEDRELUX_8,
  QUANTIZEDRELUX_8_REF,
  QUANTIZEDSIGMOID_8,
  QUANTIZEDSIGMOID_8_REF,
  QUANTIZEDTANH_8,
  QUANTIZEDTANH_8_REF,
  QUANTIZEDMAXPOOL_8,
  QUANTIZEDMAXPOOL_8_REF,
  QUANTIZEDAVGPOOL_8,
  QUANTIZEDAVGPOOL_8_REF,
  QUANTIZEDCONCAT_8,
  QUANTIZEDCONCAT_8_REF,
  QUANTIZEDBIASADD_8P8TO32,
  QUANTIZEDBIASADD_8P8TO32_REF,
  QUANTIZEDSOFTMAX_8,
  QUANTIZEDSOFTMAX_8_REF,
  QUANTIZEDLRN_8,
  QUANTIZEDLRN_8_REF,
  MIN_F,
  MIN_F_REF,
  MAX_F,
  MAX_F_REF,
  QUANTIZE,
  QUANTIZE_REF,
  DEQUANTIZE,
  DEQUANTIZE_REF,
  SUPERNODE_8X8P8TO8,
  SUPERNODE_8X8P8TO8_REF,

  QUANTIZEDFLATTEN,
  SOFTMAX_F,
  CONV2D_F,
  MATMUL_F,
  RELU_F,
  RELUX_F,
  AVGPOOL_F,
  MAXPOOL_F,
  CONCAT_F,
  BIASADD_F,
  LRN_F,

  VARIABLE,
  ASSIGN,
  RESHAPE,
  QUANTIZED_RESHAPE,
  TANH_F,
  SIGMOID_F,
  SLICE_8,
  SLICE_F,
  QUANTIZED_SLICE_8,
  ADD_F,
  MUL_F,
  MINIMUM_F,
  MAXIMUM_F,

  REQUANTIZE_32_TO_8,
  REQUANTIZE_32_TO_8_REF,
  REQUANTIZATION_RANGE_32,
  REQUANTIZATION_RANGE_32_REF,

  NEG_F,
  SUB_F,
  ADD_N_F,
  RANGE_INT32,
  RANK_INT32,
  TRANSPOSE_INT32,
  TRANSPOSE_F,
  INSTANCE_NORM_F,
  QUANTIZED_INSTANCENORM_8,
  QUANTIZED_INSTANCENORM_8_REF,
  SUB_INT32,
  ADD_INT32,
  SPLIT_F,
  DEQUANTIZE_QINT32_F,
  PRELU_F,
  QUANTIZED_PRELU_8,
  SUM_F,
  PROD_F,
  MUL_INT32,
  LOGICAL_AND_INT32,
  LOGICALOR_INT32,
  LOGICAL_XOR_INT32,
  SPAPE_INT32,
  PACK_INT32,
  MIRROR_PAD_F,
  RESIZE_NEAREST_NEIGHBOR_F,
  STRIDED_SLICE_INT32,
  STRIDED_SLICE_F,
  EXPAND_DIMS_INT32,
  EXPAND_DIMS_F,

  LOG_SOFTMAX_F,
  SPLIT_INT32,
  QUANTIZED_SPLIT_8,

  DECONV_F,
  QUANTIZED_DECONV_8X8TO32,
  QUANTIZED_DECONV_8X8TO32_REF,
#endif

  SUPPORTED_OP_TYPE_COUNT  // TERMINATOR. DO NOT REMOVE
};

/* static */ void HexagonOpsDefinitions::EmplaceOpType(
    const string& op_type, const DataTypeVector& dt_vec,
    const SupportedOpType supported_op_type,
    std::unordered_map<string, std::vector<DataTypeToOp>>* map) {
  if (map->count(op_type) <= 0) {
    map->emplace(op_type, std::vector<DataTypeToOp>());
  }
  map->at(op_type).emplace_back(
      std::forward_as_tuple(dt_vec, supported_op_type));
}

/* static */ std::unordered_map<
    string, std::vector<HexagonOpsDefinitions::DataTypeToOp>>
HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
  std::unordered_map<string, std::vector<DataTypeToOp>> op_map;
  // Custom Op name
  EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map);
  EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map);
  EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map);
  // Special op type for hexagon
  EmplaceOpType("FLATTEN", {}, SupportedOpType::FLATTEN, &op_map);
  // Tensorflow op name
  // CAVEAT: Keep order of SupportedOpType
  EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map);
  EmplaceOpType("Placeholder", {}, SupportedOpType::NOP, &op_map);
  EmplaceOpType("Const", {}, SupportedOpType::OP_CONST, &op_map);
  EmplaceOpType("QuantizedConv2D", {}, SupportedOpType::QUANTIZEDCONV2D_8X8TO32,
                &op_map);
  EmplaceOpType("QuantizedMatMul", {}, SupportedOpType::QUANTIZEDMATMUL_8X8TO32,
                &op_map);
  EmplaceOpType("QuantizeDownAndShrinkRange", {},
                SupportedOpType::QUANTIZEDOWNANDSHRINKRANGE_32TO8, &op_map);
  EmplaceOpType("QuantizedRelu", {}, SupportedOpType::QUANTIZEDRELU_8, &op_map);
  EmplaceOpType("QuantizedReluX", {}, SupportedOpType::QUANTIZEDRELUX_8,
                &op_map);
  EmplaceOpType("QuantizedMaxPool", {}, SupportedOpType::QUANTIZEDMAXPOOL_8,
                &op_map);
  EmplaceOpType("QuantizedAvgPool", {}, SupportedOpType::QUANTIZEDAVGPOOL_8,
                &op_map);
  EmplaceOpType("QuantizedConcat", {}, SupportedOpType::QUANTIZEDCONCAT_8,
                &op_map);
  EmplaceOpType("QuantizedBiasAdd", {},
                SupportedOpType::QUANTIZEDBIASADD_8P8TO32, &op_map);
  EmplaceOpType("Min", {}, SupportedOpType::MIN_F, &op_map);
  EmplaceOpType("Max", {}, SupportedOpType::MAX_F, &op_map);
  EmplaceOpType("QuantizeV2", {}, SupportedOpType::QUANTIZE, &op_map);
  EmplaceOpType("Dequantize", {}, SupportedOpType::DEQUANTIZE, &op_map);
  EmplaceOpType("Softmax", {}, SupportedOpType::SOFTMAX_F, &op_map);
  EmplaceOpType("Reshape", {}, SupportedOpType::RESHAPE, &op_map);
  EmplaceOpType("QuantizedReshape", {}, SupportedOpType::QUANTIZED_RESHAPE,
                &op_map);
  EmplaceOpType("Sigmoid", {}, SupportedOpType::SIGMOID_F, &op_map);
  EmplaceOpType("Slice", {}, SupportedOpType::SLICE_F, &op_map);
  EmplaceOpType("Add", {}, SupportedOpType::ADD_F, &op_map);
  EmplaceOpType("Mul", {}, SupportedOpType::MUL_F, &op_map);
  EmplaceOpType("Requantize", {}, SupportedOpType::REQUANTIZE_32_TO_8, &op_map);
  EmplaceOpType("RequantizationRange", {},
                SupportedOpType::REQUANTIZATION_RANGE_32, &op_map);
  EmplaceOpType("Sub", {}, SupportedOpType::SUB_F, &op_map);
  EmplaceOpType("Pack", {}, SupportedOpType::PACK_INT32, &op_map);
  EmplaceOpType("StridedSlice", {}, SupportedOpType::STRIDED_SLICE_F, &op_map);
  EmplaceOpType("ExpandDims", {}, SupportedOpType::EXPAND_DIMS_F, &op_map);
#ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS
  EmplaceOpType("QuantizedMul", {}, SupportedOpType::QUANTIZED_MUL_8x8to32,
                &op_map);
  EmplaceOpType("QuantizedAdd", {}, SupportedOpType::QUANTIZED_ADD_8p8to32,
                &op_map);
  EmplaceOpType("Pad", {}, SupportedOpType::PAD_F, &op_map);
  EmplaceOpType("SpaceToBatchND", {}, SupportedOpType::SPACE_TO_BATCH_ND_F,
                &op_map);
  EmplaceOpType("BatchToSpaceND", {}, SupportedOpType::BATCH_TO_SPACE_ND_F,
                &op_map);
  EmplaceOpType("ResizeBilinear", {}, SupportedOpType::RESIZE_BILINEAR_F,
                &op_map);
  EmplaceOpType("ConcatV2", {}, SupportedOpType::CONCAT_V2_F, &op_map);
  EmplaceOpType("Conv2DBackpropInput", {}, SupportedOpType::DECONV_F, &op_map);

  EmplaceOpType("Tanh", {}, SupportedOpType::TANH_F, &op_map);
  EmplaceOpType("Split", {}, SupportedOpType::SPLIT_F, &op_map);
  EmplaceOpType("Transpose", {}, SupportedOpType::TRANSPOSE_F, &op_map);
  EmplaceOpType("Concat", {}, SupportedOpType::CONCAT_F, &op_map);
#endif
  return op_map;
}

HexagonOpsDefinitions::HexagonOpsDefinitions()
    : op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {}

/* static */ const IRemoteFusedGraphOpsDefinitions&
HexagonOpsDefinitions::getInstance() {
  static const HexagonOpsDefinitions instance{};
  return instance;
}

int HexagonOpsDefinitions::GetTotalOpsCount() const {
  return static_cast<int>(SupportedOpType::SUPPORTED_OP_TYPE_COUNT);
}

int HexagonOpsDefinitions::GetOpIdFor(const string& op_type,
                                      const DataTypeVector& dt_vec) const {
  if (op_name_to_soc_op_type_map_.count(op_type) > 0) {
    const std::vector<DataTypeToOp>& dt_to_op_vec =
        op_name_to_soc_op_type_map_.at(op_type);
    CHECK(!dt_to_op_vec.empty());
    // If the requested DataTypeVector is empty, return the first entry.
    if (dt_vec.empty()) {
      return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
    }
    // If only one op id is registered, with an empty DataTypeVector, we assume
    // that the op supports any data types.
    if (dt_to_op_vec.size() == 1 && std::get<0>(dt_to_op_vec.front()).empty()) {
      return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
    }
    for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) {
      if (std::get<0>(data_type_to_op) == dt_vec) {
        return static_cast<int>(std::get<1>(data_type_to_op));
      }
    }
  }
  return IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
}
}  // namespace tensorflow
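
To make the resolution order of GetOpIdFor concrete, a short illustrative sketch; the op names are entries registered in BuildOpNameToSocOpTypeMap above, and the numeric ids are whatever the enum ordering yields:

  const IRemoteFusedGraphOpsDefinitions& defs =
      HexagonOpsDefinitions::getInstance();
  // 1. Empty DataTypeVector from the caller: the first registered entry wins.
  const int id0 = defs.GetOpIdFor("QuantizedConv2D", {});
  // 2. The single entry for "QuantizedConv2D" was registered with an empty
  //    DataTypeVector, so it acts as a wildcard for any requested types.
  const int id1 = defs.GetOpIdFor("QuantizedConv2D", {DT_QUINT8});  // == id0
  // 3. Unknown op names (or unmatched type vectors) fall through to
  //    IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID.
  const int id2 = defs.GetOpIdFor("NotAnOp", {});  // == INVALID_OP_ID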
@ -1,58 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_

#include <tuple>
#include <unordered_map>
#include <vector>

#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/platform/macros.h"

namespace tensorflow {

// HexagonOpsDefinitions provides the op definitions supported in the hexagon
// library.
// TODO(satok): add a functionality to call functions in hexagon library
class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
 public:
  static const IRemoteFusedGraphOpsDefinitions& getInstance();

  int GetTotalOpsCount() const final;
  int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final;

 private:
  enum class SupportedOpType;
  using DataTypeToOp = std::tuple<DataTypeVector, SupportedOpType>;

  HexagonOpsDefinitions();

  static void EmplaceOpType(
      const string& op_type, const DataTypeVector& dt_vec,
      const SupportedOpType supported_op_type,
      std::unordered_map<string, std::vector<DataTypeToOp>>* map);

  static std::unordered_map<string, std::vector<DataTypeToOp>>
  BuildOpNameToSocOpTypeMap();

  const std::unordered_map<string, std::vector<DataTypeToOp>>
      op_name_to_soc_op_type_map_;

  TF_DISALLOW_COPY_AND_ASSIGN(HexagonOpsDefinitions);
};

}  // namespace tensorflow

#endif  // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
@ -1,34 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"

namespace tensorflow {
namespace hexagon_remote_fused_graph_executor_build {

Status BuildRemoteFusedGraphExecutor(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
  executor->reset(new HexagonControlWrapper());
  return Status::OK();
}

// Registers the builder above under HexagonControlWrapper's executor name so
// the remote fused graph runtime can look it up by string.
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
    k_hexagon_remote_fused_graph_executor_build(
        HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
        BuildRemoteFusedGraphExecutor);

}  // namespace hexagon_remote_fused_graph_executor_build
}  // namespace tensorflow
@ -1,42 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <memory>

#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace hexagon_remote_fused_graph_executor_build {

Status BuildRemoteFusedGraphExecutor(
    std::unique_ptr<IRemoteFusedGraphExecutor>* executor);

namespace {

TEST(HexagonBuildRemoteFusedGraphExecutorTest, BasicRun) {
  std::unique_ptr<IRemoteFusedGraphExecutor> executor;
  ASSERT_FALSE(static_cast<bool>(executor));
  TF_ASSERT_OK(BuildRemoteFusedGraphExecutor(&executor));
  ASSERT_TRUE(static_cast<bool>(executor));
  ASSERT_NE(RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
                "build_hexagon_remote_fused_graph_executor"),
            nullptr);
}

}  // namespace

}  // namespace hexagon_remote_fused_graph_executor_build
}  // namespace tensorflow
@ -1,94 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// Wraps the hexagon rewriter in a transform so it can be used as part of the
// graph transform tool.
// A usage example, based on the Image Understanding pipeline:
/*
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
  --in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
  --out_graph=\
    /tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
  --inputs='Mul' \
  --outputs='softmax' \
  --transforms='\
    rewrite_quantized_stripped_model_for_hexagon(
      input_shape0="1,299,299,3" \
      input_type0="float" \
    )'
*/

#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"

namespace tensorflow {
namespace graph_transforms {
constexpr const char* const INPUT_SHAPE_PREFIX = "input_shape";
constexpr const char* const INPUT_TYPE_PREFIX = "input_type";

Status RewriteQuantizedStrippedModelForHexagon(
    const GraphDef& input_graph_def, const TransformFuncContext& context,
    GraphDef* output_graph_def) {
  LOG(INFO) << "Transforming quantized stripped model to a remote fused "
               "graph execute op...";
  std::vector<std::pair<string, Tensor>> inputs;
  std::vector<string> outputs;
  for (auto i = 0; static_cast<size_t>(i) < context.input_names.size(); ++i) {
    const string& input_name = context.input_names.at(i);

    // Get input shape
    string shape_string;
    TF_RETURN_IF_ERROR(context.GetOneStringParameter(
        INPUT_SHAPE_PREFIX + std::to_string(i), "", &shape_string));
    std::vector<string> split_shape = str_util::Split(shape_string, ',');
    std::vector<int64> dims;
    for (const string& dim : split_shape) {
      int64 tmp;
      CHECK(strings::safe_strto64(dim, &tmp));
      dims.push_back(tmp);
    }

    // Get input data type
    string data_type_string;
    TF_RETURN_IF_ERROR(context.GetOneStringParameter(
        INPUT_TYPE_PREFIX + std::to_string(i), "", &data_type_string));
    DataType data_type;
    CHECK(DataTypeFromString(data_type_string, &data_type))
        << "\"" << data_type_string << "\" was an invalid type";

    LOG(INFO) << "Input(" << i << "): name = " << input_name
              << ", shape = " << shape_string
              << ", type = " << data_type_string;

    inputs.emplace_back(input_name, Tensor(data_type, TensorShape(dims)));
  }

  for (const string& output_name : context.output_names) {
    outputs.emplace_back(output_name);
  }
  GraphDef mutable_input_graph_def = input_graph_def;
  *output_graph_def = GraphTransferUtils::BuildFusedGraphDef(
      HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute_node",
      inputs, outputs, &mutable_input_graph_def);
  return Status::OK();
}

REGISTER_GRAPH_TRANSFORM("rewrite_quantized_stripped_model_for_hexagon",
                         RewriteQuantizedStrippedModelForHexagon);

}  // namespace graph_transforms
}  // namespace tensorflow