Merge branch 'master' into macos_arm64_cmake

Simon Maurer 2021-02-26 14:11:38 +01:00
commit 8817e44105
215 changed files with 5475 additions and 10300 deletions

View File

@@ -3,7 +3,6 @@
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"if_libtpu",
"tf_cc_test",
"tf_copts",
"tf_cuda_cc_test",
@@ -320,75 +319,6 @@ tf_cuda_cc_test(
],
)
cc_library(
name = "gradients_util",
srcs = [
"gradients_util.cc",
],
hdrs = [
"gradients_util.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_context",
":abstract_operation",
":abstract_tensor_handle",
":c_api",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":tape",
"@com_google_absl//absl/container:flat_hash_map",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
"//tensorflow/c:c_api",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/lib/llvm_rtti",
] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [],
),
)
cc_library(
name = "mnist_gradients_testutil",
srcs = [
"mnist_gradients_testutil.cc",
],
hdrs = [
"mnist_gradients_testutil.h",
],
visibility = [
"//tensorflow:internal",
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":gradients_util",
":tape",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:status",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "gradient_checker",
testonly = 1,
@@ -436,46 +366,6 @@ tf_cuda_cc_test(
],
)
tf_cuda_cc_test(
name = "mnist_gradients_test",
size = "small",
srcs = [
"mnist_gradients_test.cc",
],
args = ["--heap_check=local"],
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + [
"no_cuda_asan", # b/173825513
],
deps = [
":abstract_tensor_handle",
":c_api_experimental",
":c_api_unified_internal",
":gradients_internal",
":gradients_util",
":mnist_gradients_testutil",
"//tensorflow/c:c_api",
"//tensorflow/c:c_test_util",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/experimental/gradients:math_grad",
"//tensorflow/c/experimental/gradients:nn_grad",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/c/experimental/ops:math_ops",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/cc/profiler",
"//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/lib/llvm_rtti",
"//tensorflow/core/platform:tensor_float_32_utils",
"@com_google_absl//absl/container:flat_hash_set",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
],
)
cc_library(
name = "abstract_tensor_handle",
srcs = ["abstract_tensor_handle.cc"],
@@ -1157,8 +1047,6 @@ filegroup(
"gradient_checker.cc",
"gradient_checker.h",
"gradients.cc", # Uses RTTI.
"gradients_util.cc",
"gradients_util.h",
"tracing_utils.h",
"tracing_utils.cc",
"*test*",

View File

@@ -16,6 +16,9 @@ limitations under the License.
#define TENSORFLOW_C_EAGER_C_API_TEST_UTIL_H_
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/tf_datatype.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/tstring.h"
#include "tensorflow/core/platform/types.h"
#include "tensorflow/core/protobuf/tensorflow_server.pb.h"
@@ -53,6 +56,27 @@ TFE_TensorHandle* TestTensorHandleWithDimsFloat(TFE_Context* ctx, float data[],
TFE_TensorHandle* TestTensorHandleWithDimsInt(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims);
// Return a tensor handle with given type, values and dimensions.
template <class T, TF_DataType datatype>
TFE_TensorHandle* TestTensorHandleWithDims(TFE_Context* ctx, const T* data,
const int64_t* dims, int num_dims) {
TF_Status* status = TF_NewStatus();
TF_Tensor* t = TFE_AllocateHostTensor(ctx, datatype, dims, num_dims, status);
memcpy(TF_TensorData(t), data, TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status);
CHECK_EQ(TF_OK, TF_GetCode(status)) << TF_Message(status);
TF_DeleteTensor(t);
TF_DeleteStatus(status);
return th;
}
// Return a scalar tensor handle with given values.
template <class T, TF_DataType datatype>
TFE_TensorHandle* TestScalarTensorHandle(TFE_Context* ctx, const T value) {
T data[] = {value};
return TestTensorHandleWithDims<T, datatype>(ctx, data, nullptr, 0);
}
// Return a tensor handle containing a 100x100 matrix of floats
TFE_TensorHandle* TestMatrixTensorHandle100x100(TFE_Context* ctx);
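For orientation, a hypothetical usage sketch of the templated helpers added above (illustrative values only; context creation and error handling are assumed to happen in the surrounding test):

#include <cstdint>

#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_test_util.h"

// Sketch only: builds a 2x2 int32 handle through the generic template
// (equivalent to the per-type TestTensorHandleWithDimsInt declared above)
// plus a scalar double handle, a type the float/int32 helpers do not cover.
void ExampleTemplatedTestTensors(TFE_Context* ctx) {
  int32_t vals[] = {1, 2, 3, 4};
  int64_t dims[] = {2, 2};
  TFE_TensorHandle* matrix = TestTensorHandleWithDims<int32_t, TF_INT32>(
      ctx, vals, dims, /*num_dims=*/2);
  TFE_TensorHandle* scalar =
      TestScalarTensorHandle<double, TF_DOUBLE>(ctx, 3.5);
  TFE_DeleteTensorHandle(matrix);
  TFE_DeleteTensorHandle(scalar);
}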

View File

@@ -29,8 +29,9 @@ using namespace std;
// ================== Helper functions =================
// Fills data with values [start,end) with given step size.
void Range(vector<int>* data, int start, int end, int step = 1) {
for (int i = start; i < end; i += step) {
void Range(vector<int32_t>* data, int32_t start, int32_t end,
int32_t step = 1) {
for (int32_t i = start; i < end; i += step) {
(*data)[i] = i;
}
}
@@ -72,12 +73,12 @@ Status RunAndMaybeSum(AbstractContext* ctx, Model forward,
// Will sum all dimensions, so get a Tensor containing [0,...,num_dims_out-1].
AbstractTensorHandlePtr sum_dims;
{
vector<int> vals(num_dims_out);
vector<int32_t> vals(num_dims_out);
int64_t vals_shape[] = {num_dims_out};
Range(&vals, 0, num_dims_out);
AbstractTensorHandle* sum_dims_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsInt(ctx, vals.data(), vals_shape,
1, &sum_dims_raw));
TF_RETURN_IF_ERROR(TestTensorHandleWithDims<int32_t, TF_INT32>(
ctx, vals.data(), vals_shape, 1, &sum_dims_raw));
sum_dims.reset(sum_dims_raw);
}
@@ -130,8 +131,8 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
AbstractTensorHandlePtr two_eps;
{
AbstractTensorHandle* two_eps_raw = nullptr;
TF_RETURN_IF_ERROR(
TestScalarTensorHandle(ctx, 2 * epsilon, &two_eps_raw));
TF_RETURN_IF_ERROR(TestScalarTensorHandle<float, TF_FLOAT>(
ctx, 2 * epsilon, &two_eps_raw));
two_eps.reset(two_eps_raw);
}
@@ -142,7 +143,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
AbstractTensorHandlePtr thetaPlus;
{
AbstractTensorHandle* thetaPlus_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
TF_RETURN_IF_ERROR(TestTensorHandleWithDims<float, TF_FLOAT>(
ctx, thetaPlus_data.data(), theta_dims.data(), num_dims,
&thetaPlus_raw));
thetaPlus.reset(thetaPlus_raw);
@@ -155,7 +156,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
AbstractTensorHandlePtr thetaMinus;
{
AbstractTensorHandle* thetaMinus_raw = nullptr;
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
TF_RETURN_IF_ERROR(TestTensorHandleWithDims<float, TF_FLOAT>(
ctx, thetaMinus_data.data(), theta_dims.data(), num_dims,
&thetaMinus_raw));
thetaMinus.reset(thetaMinus_raw);
@@ -194,7 +195,7 @@ Status CalcNumericalGrad(AbstractContext* ctx, Model forward,
}
// Populate *numerical_grad with the data from dtheta_approx.
TF_RETURN_IF_ERROR(TestTensorHandleWithDimsFloat(
TF_RETURN_IF_ERROR(TestTensorHandleWithDims<float, TF_FLOAT>(
ctx, dtheta_approx.data(), theta_dims.data(), num_dims, numerical_grad));
TF_DeleteTensor(theta_tensor);
return Status::OK();
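The gradient_checker changes above follow the same migration for the Status-returning test helpers: per-type TestTensorHandleWithDimsFloat / TestScalarTensorHandle calls become explicit <T, TF_DataType> instantiations of the templated overloads declared in the unified-experimental test header at the end of this diff. A minimal hypothetical wrapper showing the pattern (the argument names are placeholders, not values from the diff):

// Sketch only; assumes the templated overload taking an AbstractTensorHandle**
// is visible, as at the CalcNumericalGrad call sites above.
Status MakeFloatHandle(AbstractContext* ctx, float* vals, int64_t* dims,
                       int num_dims, AbstractTensorHandlePtr* out) {
  AbstractTensorHandle* raw = nullptr;
  // Old: TestTensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &raw);
  TF_RETURN_IF_ERROR(TestTensorHandleWithDims<float, TF_FLOAT>(
      ctx, vals, dims, num_dims, &raw));
  out->reset(raw);
  return Status::OK();
}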

View File

@@ -117,8 +117,8 @@ TEST_P(GradientCheckerTest, TestMatMul) {
AbstractTensorHandlePtr A;
{
AbstractTensorHandle* A_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
Status s = TestTensorHandleWithDims<float, TF_FLOAT>(ctx_.get(), A_vals,
A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
A.reset(A_raw);
}
@@ -127,8 +127,8 @@ TEST_P(GradientCheckerTest, TestMatMul) {
AbstractTensorHandlePtr B;
{
AbstractTensorHandle* B_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), B_vals, B_dims, 2, &B_raw);
Status s = TestTensorHandleWithDims<float, TF_FLOAT>(ctx_.get(), B_vals,
B_dims, 2, &B_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
B.reset(B_raw);
}
@@ -143,7 +143,8 @@ TEST_P(GradientCheckerTest, TestMul) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx_.get(), 2.0f, &x_raw);
Status s =
TestScalarTensorHandle<float, TF_FLOAT>(ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
@@ -151,7 +152,8 @@ TEST_P(GradientCheckerTest, TestMul) {
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
Status s = TestScalarTensorHandle(ctx_.get(), 7.0f, &y_raw);
Status s =
TestScalarTensorHandle<float, TF_FLOAT>(ctx_.get(), 7.0f, &y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
y.reset(y_raw);
}

View File

@@ -58,100 +58,10 @@ class CppGradients
};
Status RegisterGradients(GradientRegistry* registry) {
// TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
// AddV2Registerer.
TF_RETURN_IF_ERROR(registry->Register("IdentityN", IdentityNRegisterer));
TF_RETURN_IF_ERROR(RegisterNotDifferentiable(registry, "CheckNumerics"));
return Status::OK();
}
// Computes
// ignored, y = IdentityN(inputs[0], inputs[1])
// return grad(y, {inputs[0], inputs[1]})
// This should return [nullptr, 1].
Status IdentityNGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(RegisterGradients(&registry));
auto tape = std::make_unique<Tape>(/*persistent=*/false);
tape->Watch(inputs[0]);
tape->Watch(inputs[1]);
vector<AbstractTensorHandle*> identity_n_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape.get(), registry));
TF_RETURN_IF_ERROR(ops::IdentityN(
tape_ctx.get(), inputs, absl::MakeSpan(identity_n_outputs), "IdentityN"));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx,
/*targets=*/{identity_n_outputs[1]},
/*sources=*/{inputs[0], inputs[1]},
/*output_gradients=*/{}, outputs));
for (auto identity_n_output : identity_n_outputs) {
identity_n_output->Unref();
}
return Status::OK();
}
TEST_P(CppGradients, TestIdentityNGrad) {
// Pseudo-code:
//
// tape.watch(x1)
// tape.watch(x2)
// unused, y = IdentityN([x1, x2])
// outputs = tape.gradient(y, [x1, x2])
// Expected: [nullptr, 1]
//
// This test is interesting because the current implementation of GradientTape
// would return [0, 1] whereas we use build_default_zeros_grads=false here
// so we get back [nullptr, 1].
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr x1;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x1.reset(x_raw);
}
AbstractTensorHandlePtr x2;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x2.reset(x_raw);
}
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(IdentityNGradModel, ctx.get(), {x1.get(), x2.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()));
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
EXPECT_EQ(outputs[0], nullptr);
TF_Tensor* result_tensor;
s = GetValue(outputs[1], &result_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
auto result_value = static_cast<float*>(TF_TensorData(result_tensor));
EXPECT_EQ(*result_value, 1.0);
outputs[1]->Unref();
TF_DeleteTensor(result_tensor);
result_tensor = nullptr;
}
TEST_P(CppGradients, TestSetAttrString) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
@@ -167,7 +77,7 @@ TEST_P(CppGradients, TestSetAttrString) {
AbstractTensorHandlePtr t;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
t.reset(x_raw);
}
@@ -232,7 +142,7 @@ TEST_P(CppGradients, TestRecordOperationWithNullGradientFunctionRaises) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}

View File

@@ -1,320 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/gradients_util.h"
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
namespace tensorflow {
namespace gradients {
using namespace std;
Status ScalarTensorHandleHelper(TFE_Context* ctx, float value,
TFE_TensorHandle** result) {
float data[] = {value};
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, nullptr, 0, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsFloatHelper(TFE_Context* ctx, float data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_FLOAT, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
Status TensorHandleWithDimsIntHelper(TFE_Context* ctx, int data[],
int64_t dims[], int num_dims,
TFE_TensorHandle** result) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_Tensor* t =
TFE_AllocateHostTensor(ctx, TF_INT32, &dims[0], num_dims, status.get());
memcpy(TF_TensorData(t), &data[0], TF_TensorByteSize(t));
TFE_TensorHandle* th = TFE_NewTensorHandleFromTensor(ctx, t, status.get());
*result = th;
TF_DeleteTensor(t);
return StatusFromTF_Status(status.get());
}
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(ScalarTensorHandleHelper(eager_ctx, value, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsFloatHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager;
TF_RETURN_IF_ERROR(TensorHandleWithDimsIntHelper(eager_ctx, data, dims,
num_dims, &input_eager));
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return StatusFromTF_Status(status.get());
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_TensorHandle* result_t =
TF_AbstractTensorGetEagerTensor(wrap(t), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
*result_tensor = TFE_TensorHandleResolve(result_t, status.get());
return StatusFromTF_Status(status.get());
}
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsFloat(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims) {
AbstractTensorHandlePtr A;
AbstractTensorHandle* a_raw = nullptr;
Status s = TensorHandleWithDimsInt(ctx, vals, dims, num_dims, &a_raw);
if (s.ok()) {
A.reset(a_raw);
}
return A;
}
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val) {
AbstractTensorHandlePtr y;
AbstractTensorHandle* y_raw = nullptr;
Status s = ScalarTensorHandle(ctx, val, &y_raw);
if (s.ok()) {
y.reset(y_raw);
}
return y;
}
Status UpdateWeights(AbstractContext* ctx, vector<AbstractTensorHandle*>& grads,
vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate) {
/* Update weights one by one using gradient update rule:
*
* w -= lr*grad[w]
*
* NOTE: assuming learning rate is positive
*/
int num_grads = grads.size();
vector<AbstractTensorHandle*> temp_outputs(1);
std::string update_str;
// Negate learning rate for gradient descent
TF_RETURN_IF_ERROR(ops::Neg(ctx, {learning_rate},
absl::MakeSpan(temp_outputs),
"neg_lr")); // Compute -lr
learning_rate = temp_outputs[0];
for (int i = 0; i < num_grads; i++) {
// Compute dW = -lr * grad(w[i])
update_str = "update_mul_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Mul(ctx, {learning_rate, grads[i]},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
AbstractTensorHandle* dW = temp_outputs[0];
// Compute temp = weights[i] + dW
update_str = "update_add_" + std::to_string(i);
TF_RETURN_IF_ERROR(ops::Add(ctx, {weights[i], dW},
absl::MakeSpan(temp_outputs),
update_str.c_str()));
// Update the weights
weights[i] = temp_outputs[0];
}
return Status::OK();
}
AbstractContext* BuildFunction(const char* fn_name) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TF_ExecutionContext* graph_ctx = TF_CreateFunction(fn_name, status.get());
return unwrap(graph_ctx);
}
Status CreateParamsForInputs(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
vector<AbstractTensorHandle*>* params) {
tracing::TracingTensorHandle* handle = nullptr;
for (auto input : inputs) {
PartialTensorShape shape;
TF_RETURN_IF_ERROR(input->Shape(&shape));
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(ctx)->AddParameter(
input->DataType(), shape, &handle));
params->emplace_back(handle);
}
return Status::OK();
}
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry) {
if (use_function) {
const char* fn_name = "test_fn";
std::unique_ptr<AbstractFunction> scoped_func;
// Returning null tensors from a tf.function is not supported, so we keep
// track of which indices in the model's outputs are nullptr in this set.
// The FunctionDef only outputs the non-null tensors. We later pad the
// function op outputs to have nullptrs at the `null_indices`.
absl::flat_hash_set<int> null_indices;
{
AbstractContextPtr func_ctx(BuildFunction(fn_name));
vector<AbstractTensorHandle*> func_inputs;
func_inputs.reserve(inputs.size());
TF_RETURN_IF_ERROR(
CreateParamsForInputs(func_ctx.get(), inputs, &func_inputs));
vector<AbstractTensorHandle*> model_outputs;
model_outputs.resize(outputs.size());
TF_RETURN_IF_ERROR(model(func_ctx.get(), absl::MakeSpan(func_inputs),
absl::MakeSpan(model_outputs), registry));
for (auto func_input : func_inputs) {
func_input->Unref();
}
AbstractFunction* func = nullptr;
OutputList output_list;
output_list.expected_num_outputs = 0;
output_list.outputs.reserve(outputs.size());
for (int i = 0; i < model_outputs.size(); i++) {
if (model_outputs[i]) {
output_list.outputs.emplace_back(model_outputs[i]);
output_list.expected_num_outputs += 1;
} else {
null_indices.insert(i);
}
}
TF_RETURN_IF_ERROR(dyn_cast<tracing::TracingContext>(func_ctx.get())
->Finalize(&output_list, &func));
scoped_func.reset(func);
for (auto output : output_list.outputs) {
output->Unref();
}
TF_RETURN_IF_ERROR(ctx->RegisterFunction(func));
}
AbstractOperationPtr fn_op(ctx->CreateOperation());
TF_RETURN_IF_ERROR(fn_op->Reset(fn_name, /*raw_device_name=*/nullptr));
for (auto input : inputs) {
TF_RETURN_IF_ERROR(fn_op->AddInput(input));
}
int retvals = outputs.size() - null_indices.size();
vector<AbstractTensorHandle*> fn_outputs(retvals);
TF_RETURN_IF_ERROR(fn_op->Execute(
absl::Span<AbstractTensorHandle*>(fn_outputs.data(), fn_outputs.size()),
&retvals));
int skipped_indices = 0;
for (int i = 0; i < outputs.size(); i++) {
if (!null_indices.contains(i)) {
outputs[i] = fn_outputs[i - skipped_indices];
} else {
skipped_indices += 1;
}
}
TF_RETURN_IF_ERROR(ctx->RemoveFunction(fn_name));
return Status::OK();
} else {
return model(ctx, inputs, outputs, registry);
}
}
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_ContextOptions* opts = TFE_NewContextOptions();
TFE_ContextOptionsSetTfrt(opts, use_tfrt);
*ctx = unwrap(TF_NewEagerExecutionContext(opts, status.get()));
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_DeleteContextOptions(opts);
return Status::OK();
}
} // namespace gradients
} // namespace tensorflow
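As context for the file removed above: RunModel either runs the model eagerly on the given context or, when use_function is true, traces it into a FunctionDef, registers and executes it as a single op, and then re-inserts nullptr results at the recorded null_indices. A hypothetical driver, with tensor setup and gradient registration assumed as in the tests further down:

// Sketch only; MatMulGradModel comes from the removed mnist_gradients_testutil
// files below, and A/B are assumed to be valid float matrix handles.
Status RunMatMulGradTraced(AbstractContext* ctx, AbstractTensorHandle* A,
                           AbstractTensorHandle* B,
                           const GradientRegistry& registry,
                           std::vector<AbstractTensorHandle*>* grads) {
  grads->resize(2);
  // use_function=true exercises the tracing path; false runs the ops eagerly.
  return RunModel(MatMulGradModel, ctx, {A, B}, absl::MakeSpan(*grads),
                  /*use_function=*/true, registry);
}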

View File

@@ -1,88 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/types.h"
namespace tensorflow {
namespace gradients {
// Get a scalar TensorHandle with given value
Status ScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given float values and dimensions
Status TensorHandleWithDimsFloat(AbstractContext* ctx, float data[],
int64_t dims[], int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions
Status TensorHandleWithDimsInt(AbstractContext* ctx, int data[], int64_t dims[],
int num_dims, AbstractTensorHandle** tensor);
// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilFloat(AbstractContext* ctx,
float vals[], int64_t dims[],
int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data and dims.
AbstractTensorHandlePtr GetTensorHandleUtilInt(AbstractContext* ctx, int vals[],
int64_t dims[], int num_dims);
// Util function that wraps an AbstractTensorHandle* with given data.
AbstractTensorHandlePtr GetScalarTensorHandleUtil(AbstractContext* ctx,
float val);
// Performs gradient update for each weight using given learning rate.
Status UpdateWeights(AbstractContext* ctx,
std::vector<AbstractTensorHandle*>& grads,
std::vector<AbstractTensorHandle*>& weights,
AbstractTensorHandle* learning_rate);
using Model = std::function<Status(
AbstractContext*, absl::Span<AbstractTensorHandle* const>,
absl::Span<AbstractTensorHandle*>, const GradientRegistry&)>;
// Runs given model in either graph or eager mode depending on value of
// use_function.
Status RunModel(Model model, AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs, bool use_function,
const GradientRegistry& registry);
// Builds context and returns inside *ctx.
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
} // namespace gradients
} // namespace tensorflow

View File

@@ -1,602 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include "tensorflow/c/experimental/gradients/math_grad.h"
#include "tensorflow/c/experimental/gradients/nn_grad.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/errors.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
CHECK_EQ(errors::OK, s.code()) << s.error_message();
// Computing numerical gradients with TensorFloat-32 is numerically
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
// low tolerances
enable_tensor_float_32_execution(false);
}
};
Status RegisterGradients(GradientRegistry* registry) {
TF_RETURN_IF_ERROR(registry->Register("Add", AddRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Exp", ExpRegisterer));
TF_RETURN_IF_ERROR(registry->Register("MatMul", MatMulRegisterer));
TF_RETURN_IF_ERROR(registry->Register("Relu", ReluRegisterer));
TF_RETURN_IF_ERROR(
registry->Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
return Status::OK();
}
TEST_P(CppGradients, TestMatMulGrad) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
float B_vals[] = {.5f, -1.0f, 1.0f, 1.0f};
int64_t B_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
AbstractTensorHandlePtr B =
GetTensorHandleUtilFloat(ctx.get(), B_vals, B_dims, num_dims);
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
* tape.watch(A)
* tape.watch(B)
* Y = AB
* outputs = tape.gradient(Y, [A, B])
*/
std::vector<AbstractTensorHandle*> outputs(2);
s = RunModel(MatMulGradModel, ctx.get(), {A.get(), B.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float expected_dA[4] = {-.5f, 2.0f, -.5f, 2.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dA[j], tolerance);
}
TF_Tensor* dB_tensor;
s = GetValue(outputs[1], &dB_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(dB_tensor),
TF_TensorByteSize(dB_tensor));
float expected_dB[4] = {4.0f, 4.0f, 6.0f, 6.0f};
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dB[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(dA_tensor);
TF_DeleteTensor(dB_tensor);
}
TEST_P(CppGradients, TestMNISTForward) {
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t dims_y[] = {2};
num_dims = sizeof(dims_y) / sizeof(dims_y[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, dims_y, num_dims);
GradientRegistry registry;
// Run the Forward Pass
std::vector<AbstractTensorHandle*> outputs(2);
Status s =
RunModel(MNISTForwardModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[4] = {3.6f, -6.0f, 10.2f, -17.0f};
float tolerance = 1e-3;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
TF_Tensor* loss_vals_tensor;
s = GetValue(outputs[1], &loss_vals_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
TF_TensorByteSize(loss_vals_tensor));
float expected_losses[2] = {9.6f, 27.2f};
for (int j = 0; j < 2; j++) {
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(scores_tensor);
TF_DeleteTensor(loss_vals_tensor);
}
TEST_P(CppGradients, TestMNISTForward2) {
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t X_dims[] = {3, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1, 1};
int64_t y_dims[] = {3};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
GradientRegistry registry;
// Run the Forward Pass
std::vector<AbstractTensorHandle*> outputs(2);
Status s =
RunModel(MNISTForwardModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[6] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[6] = {3.6f, -6.0f, 10.2f, -17.0f, 16.8f, -28.0f};
float tolerance = 1e-3;
for (int j = 0; j < 6; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
TF_Tensor* loss_vals_tensor;
s = GetValue(outputs[1], &loss_vals_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(loss_vals_tensor),
TF_TensorByteSize(loss_vals_tensor));
float expected_losses[3] = {9.6f, 27.2f, 44.8f};
for (int j = 0; j < 3; j++) {
ASSERT_NEAR(result_data[j], expected_losses[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
TF_DeleteTensor(scores_tensor);
TF_DeleteTensor(loss_vals_tensor);
}
TEST_P(CppGradients, TestMatMulTranspose) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f};
int64_t X_dims[] = {2, 3};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
GradientRegistry registry;
// Run the MatMul Op
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(MatMulTransposeModel, ctx.get(), {X.get(), W1.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Verify the Results
TF_Tensor* scores_tensor;
s = GetValue(outputs[0], &scores_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[6] = {0};
memcpy(&result_data[0], TF_TensorData(scores_tensor),
TF_TensorByteSize(scores_tensor));
float expected_scores[6] = {13.0f, 18.0f, 17.0f, 24.0f, 21.0f, 30.0f};
float tolerance = 1e-3;
for (int j = 0; j < 6; j++) {
ASSERT_NEAR(result_data[j], expected_scores[j], tolerance);
}
}
TEST_P(CppGradients, TestMNISTGrad) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t X_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// W1 = first weights
float W1_vals[] = {-1.0f, 10.0f, .5f, 1.0f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t y_dims[] = {2};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
// Register Grads
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
/* Pseudo-code:
*
*
* tape.watch(W1)
* tape.watch(W2)
* mm = X*W1
* hidden = Relu(mm)
* scores = W2*hidden
* loss = SoftmaxLoss(scores, y)
* outputs = tape.gradient(loss, [W1, W2])
*
*/
std::vector<AbstractTensorHandle*> outputs(3);
s = RunModel(MNISTGradModel, ctx.get(),
{X.get(), W1.get(), W2.get(), y.get()}, absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float tolerance = 1e-3;
TF_Tensor* dW1_tensor;
s = GetValue(outputs[0], &dW1_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dW1_tensor),
TF_TensorByteSize(dW1_tensor));
float expected_dW1[4] = {0.0f, 3.2f, 0.0f, 4.8f};
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dW1[j], tolerance);
}
TF_Tensor* dW2_tensor;
s = GetValue(outputs[1], &dW2_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
memcpy(&result_data[0], TF_TensorData(dW2_tensor),
TF_TensorByteSize(dW2_tensor));
float expected_dW2[4] = {0.0f, 0.0f, 46.0f, -46.0f}; // dLoss
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], expected_dW2[j], tolerance);
}
outputs[0]->Unref();
outputs[1]->Unref();
outputs[2]->Unref();
TF_DeleteTensor(dW1_tensor);
TF_DeleteTensor(dW2_tensor);
}
TEST_P(CppGradients, TestScalarMul) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
AbstractTensorHandlePtr eta;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = ScalarTensorHandle(ctx.get(), 1.5f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
eta.reset(x_raw);
}
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t A_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr A =
GetTensorHandleUtilFloat(ctx.get(), A_vals, A_dims, num_dims);
GradientRegistry registry;
std::vector<AbstractTensorHandle*> outputs(1);
Status s = RunModel(ScalarMulModel, ctx.get(), {eta.get(), A.get()},
absl::MakeSpan(outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
TF_Tensor* dA_tensor;
s = GetValue(outputs[0], &dA_tensor);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
float result_data[4] = {0};
memcpy(&result_data[0], TF_TensorData(dA_tensor),
TF_TensorByteSize(dA_tensor));
float tolerance = 1e-3;
float eta_val = 1.5f;
for (int j = 0; j < 4; j++) {
ASSERT_NEAR(result_data[j], eta_val * A_vals[j], tolerance);
}
outputs[0]->Unref();
TF_DeleteTensor(dA_tensor);
}
TEST_P(CppGradients, TestMNIST_Training) {
bool use_function = !std::get<2>(GetParam());
if (use_function) {
// TODO(b/168850692): Enable this.
GTEST_SKIP() << "Can't take gradient of "
"SparseSoftmaxCrossEntropyWithLogits in tracing mode.";
}
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
AbstractContextPtr ctx;
{
AbstractContext* ctx_raw = nullptr;
Status s =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx.reset(ctx_raw);
}
// X = data
float X_vals[] = {1.0f, 2.0f, 3.0f, 4.0f};
int64_t X_dims[] = {2, 2};
int num_dims = 2;
AbstractTensorHandlePtr X =
GetTensorHandleUtilFloat(ctx.get(), X_vals, X_dims, num_dims);
// TODO(amturati): use random initializer for weights instead of
// constant values.
// W1 = first weights
float W1_vals[] = {-.01f, 0.4f, 0.5f, -.2f};
int64_t dims[] = {2, 2};
AbstractTensorHandlePtr W1 =
GetTensorHandleUtilFloat(ctx.get(), W1_vals, dims, num_dims);
// W2 = second weights
float W2_vals[] = {.1f, .2f, .3f, -.5f};
AbstractTensorHandlePtr W2 =
GetTensorHandleUtilFloat(ctx.get(), W2_vals, dims, num_dims);
// y = labels
int y_vals[] = {1, 1};
int64_t y_dims[] = {2};
num_dims = sizeof(y_dims) / sizeof(y_dims[0]);
AbstractTensorHandlePtr y =
GetTensorHandleUtilInt(ctx.get(), y_vals, y_dims, num_dims);
// Register Grads
GradientRegistry registry;
Status s = RegisterGradients(&registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Prepare for training
std::vector<AbstractTensorHandle*> weights;
weights.push_back(W1.get());
weights.push_back(W2.get());
// Set learning rate to be 1e-1
AbstractTensorHandle* learning_rate = nullptr;
s = ScalarTensorHandle(ctx.get(), 1e-1, &learning_rate);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Train
int num_iters = 10;
std::vector<AbstractTensorHandle*> mnist_outputs(3);
std::vector<AbstractTensorHandle*> grads(2);
for (int i = 0; i < num_iters; i++) {
// Run Forward Pass
s = RunModel(MNISTGradModel, ctx.get(),
{X.get(), weights[0], weights[1], y.get()},
absl::MakeSpan(mnist_outputs),
/*use_function=*/!std::get<2>(GetParam()), registry);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
// Fill grads
grads[0] = mnist_outputs[0];
grads[1] = mnist_outputs[1];
// Gradient Update
s = UpdateWeights(ctx.get(), grads, weights, learning_rate);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
}
grads[0]->Unref(); // release W1_grad
grads[1]->Unref(); // release W2_grad
mnist_outputs[2]->Unref(); // release loss
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*executing_eagerly*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@@ -1,244 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/eager/mnist_gradients_testutil.h"
#include <memory>
#include "absl/container/flat_hash_set.h"
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/eager/gradients_util.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
namespace tensorflow {
namespace gradients {
namespace internal {
using std::vector;
//===================== Test Models to run =========================
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
auto tape = new Tape(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
std::vector<AbstractTensorHandle*> add_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(
ops::Add(tape_ctx.get(), inputs, absl::MakeSpan(add_outputs), "Add"));
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/add_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto add_output : add_outputs) {
add_output->Unref();
}
delete tape;
return Status::OK();
}
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
auto tape = new Tape(/*persistent=*/false);
tape->Watch(inputs[0]); // Watch x.
tape->Watch(inputs[1]); // Watch y.
vector<AbstractTensorHandle*> mm_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), inputs,
absl::MakeSpan(mm_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute x*y.
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/mm_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto mm_output : mm_outputs) {
mm_output->Unref();
}
delete tape;
return Status::OK();
}
// Model to run 2-layer net
Status MNISTForwardModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
/**
* We will trace a 2-layer fully connected network for an MNIST model:
*
* def mnist_forward(X, W1, W2, y_labels):
* mm_out_1 = tf.matmul(X,W1)
* hidden_layer = tf.nn.relu(mm_out_1)
* scores = tf.matmul(hidden_layer,W2)
* softmax =
* tf.nn.sparse_softmax_cross_entropy_with_logits(scores,
* y_labels)
* return scores, softmax
*
* Use this convention for inputs:
*
* inputs = [X, W1, W2, y_labels]
*
*/
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
vector<AbstractTensorHandle*> temp_outputs(1);
TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, absl::MakeSpan(temp_outputs),
"matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
TF_RETURN_IF_ERROR(ops::Relu(ctx, {temp_outputs[0]},
absl::MakeSpan(temp_outputs),
"relu")); // Compute Relu(X*W1)
TF_RETURN_IF_ERROR(ops::MatMul(
ctx, {temp_outputs[0], W2}, absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false)); // Compute W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
ctx, {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmax_loss")); // Compute Softmax(Scores,labels)
AbstractTensorHandle* loss_vals = temp_outputs[0];
outputs[0] = scores;
outputs[1] = loss_vals;
return Status::OK();
}
Status MatMulTransposeModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
TF_RETURN_IF_ERROR(ops::MatMul(ctx, {X, W1}, outputs, "matmul0",
/*transpose_a=*/true,
/*transpose_b=*/false)); // Compute X*W1
return Status::OK();
}
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
AbstractTensorHandle* X = inputs[0];
AbstractTensorHandle* W1 = inputs[1];
AbstractTensorHandle* W2 = inputs[2];
AbstractTensorHandle* y_labels = inputs[3];
auto tape = new Tape(/*persistent=*/true);
tape->Watch(X); // Watch X.
tape->Watch(W1); // Watch W1.
tape->Watch(W2); // Watch W2.
vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, tape, registry));
TF_RETURN_IF_ERROR(ops::MatMul(tape_ctx.get(), {X, W1},
absl::MakeSpan(temp_outputs), "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false)); // Compute X*W1
AbstractTensorHandle* mm = temp_outputs[0];
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), {mm},
absl::MakeSpan(temp_outputs), // Relu(X*W1)
"relu0"));
AbstractTensorHandle* hidden = temp_outputs[0];
TF_RETURN_IF_ERROR(ops::MatMul(
tape_ctx.get(), {hidden, W2}, absl::MakeSpan(temp_outputs), "matmul1",
/*transpose_a=*/false, /*transpose_b=*/false)); // W2*Relu(X*W1)
AbstractTensorHandle* scores = temp_outputs[0];
temp_outputs.resize(2);
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), {scores, y_labels}, absl::MakeSpan(temp_outputs),
"softmaxloss")); // W2*Relu(X*W1)
AbstractTensorHandle* loss = temp_outputs[0];
TF_RETURN_IF_ERROR(tape->ComputeGradient(ctx, /*targets=*/{loss},
/*sources=*/{W1, W2},
/*output_gradients=*/{},
outputs.subspan(0, 2)));
// Only release 2nd temp output as first holds loss values.
temp_outputs[1]->Unref();
outputs[2] = loss;
delete tape;
return Status::OK();
}
Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
return ops::Mul(ctx, inputs, outputs,
"scalarMul0"); // Compute eta*A
}
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
return ops::MatMul(ctx, inputs, outputs, "matmul0",
/*transpose_a=*/false,
/*transpose_b=*/false); // Compute X*W1
}
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry) {
return ops::Mul(ctx, inputs, outputs,
"mul0"); // Compute x*y
}
// ============================= End Models ================================
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@@ -1,90 +0,0 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#define TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_
#include <memory>
#include "absl/types/span.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/gradients.h"
#include "tensorflow/c/eager/gradients_internal.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/experimental/ops/math_ops.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/core/lib/llvm_rtti/llvm_rtti.h"
#include "tensorflow/core/platform/status.h"
namespace tensorflow {
namespace gradients {
namespace internal {
// Computes
// y = inputs[0] + inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status AddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes
// y = inputs[0] * inputs[1]
// return grad(y, {inputs[0], inputs[1]})
Status MatMulGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes 2-layer Neural Network with Softmax Loss.
Status MNISTForwardModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Computes MatMul with first matrix transposed.
Status MatMulTransposeModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify Multi-grad functionality for MNIST
Status MNISTGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
// Test Model to verify scalar-tensor multiplication Op
Status ScalarMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
Status MatMulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
Status MulModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs,
const GradientRegistry& registry);
} // namespace internal
} // namespace gradients
} // namespace tensorflow
#endif // TENSORFLOW_C_EAGER_MNIST_GRADIENTS_TESTUTIL_H_

View File

@@ -68,7 +68,7 @@ TEST_P(UnifiedAPI, TestTensorShapeScalar) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 2.0f, &x_raw);
Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}
@ -120,8 +120,8 @@ TEST_P(UnifiedAPI, TestTensorShape2x4) {
AbstractTensorHandle* x_raw = nullptr;
float data[] = {0., 0., 0., 0., 0., 0., 0., 0};
int64_t dim_sizes[] = {2, 4};
Status s =
TestTensorHandleWithDimsFloat(ctx.get(), data, dim_sizes, 2, &x_raw);
Status s = TestTensorHandleWithDims<float, TF_FLOAT>(ctx.get(), data,
dim_sizes, 2, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}

View File

@ -130,49 +130,6 @@ Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx) {
return Status::OK();
}
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager = TestScalarTensorHandle(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsFloat(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDimsInt(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);

View File

@ -17,6 +17,10 @@ limitations under the License.
#include "tensorflow/c/eager/abstract_context.h"
#include "tensorflow/c/eager/abstract_tensor_handle.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/c/tf_tensor.h"
#include "tensorflow/core/platform/status.h"
@ -49,19 +53,38 @@ Status RunModel(Model model, AbstractContext* ctx,
Status BuildImmediateExecutionContext(bool use_tfrt, AbstractContext** ctx);
// Get a Scalar TensorHandle with given float value.
Status TestScalarTensorHandle(AbstractContext* ctx, float value,
AbstractTensorHandle** tensor);
// Return a tensor handle with given type, values and dimensions.
template <class T, TF_DataType datatype>
Status TestTensorHandleWithDims(AbstractContext* ctx, const T* data,
const int64_t* dims, int num_dims,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestTensorHandleWithDims<T, datatype>(eager_ctx, data, dims, num_dims);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Get a Matrix TensorHandle with given float values and dimensions.
Status TestTensorHandleWithDimsFloat(AbstractContext* ctx, float* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor);
// Get a TensorHandle with given int values and dimensions
Status TestTensorHandleWithDimsInt(AbstractContext* ctx, int* data,
int64_t* dims, int num_dims,
AbstractTensorHandle** tensor);
// Return a scalar tensor handle with given value.
template <class T, TF_DataType datatype>
Status TestScalarTensorHandle(AbstractContext* ctx, const T value,
AbstractTensorHandle** tensor) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
TFE_Context* eager_ctx =
TF_ExecutionContextGetTFEContext(wrap(ctx), status.get());
TF_RETURN_IF_ERROR(StatusFromTF_Status(status.get()));
TFE_TensorHandle* input_eager =
TestScalarTensorHandle<T, datatype>(eager_ctx, value);
*tensor =
unwrap(TF_CreateAbstractTensorFromEagerTensor(input_eager, status.get()));
return Status::OK();
}
// Places data from `t` into *result_tensor.
Status GetValue(AbstractTensorHandle* t, TF_Tensor** result_tensor);
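// Editorial note: a minimal usage sketch (hedged) of the templated helpers
// above, assuming an AbstractContextPtr `ctx` built via
// BuildImmediateExecutionContext as in the tests elsewhere in this change;
// the variable names here are hypothetical.
//
//   AbstractTensorHandle* scalar = nullptr;
//   Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 2.0f, &scalar);
//
//   int32_t vals[] = {1, 0, 1};
//   int64_t dims[] = {3};
//   AbstractTensorHandle* vec = nullptr;
//   s = TestTensorHandleWithDims<int32_t, TF_INT32>(ctx.get(), vals, dims,
//                                                   /*num_dims=*/1, &vec);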

View File

@ -190,6 +190,7 @@ tf_cuda_cc_test(
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:nn_ops",
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
] + if_libtpu(
@ -222,3 +223,28 @@ tf_cuda_cc_test(
if_true = [],
),
)
tf_cuda_cc_test(
name = "array_grad_test",
size = "small",
srcs = [
"array_grad_test.cc",
],
args = ["--heap_check=local"], # TODO(b/174752220): Remove
linkstatic = tf_kernel_tests_linkstatic(),
tags = tf_cuda_tests_tags() + ["no_cuda_asan"], # b/173654156,
deps = [
":grad_test_helper",
":array_grad",
"//tensorflow/c:tf_status_helper",
"//tensorflow/c/eager:c_api_test_util",
"//tensorflow/c/experimental/gradients/tape:tape_context",
"//tensorflow/c/experimental/ops:array_ops",
"//tensorflow/core/platform:tensor_float_32_utils",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
] + if_libtpu(
if_false = ["//tensorflow/compiler/mlir/tensorflow/c:mlir_c_api_registration"],
if_true = [],
),
)

View File

@ -0,0 +1,133 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/c/experimental/gradients/array_grad.h"
#include "tensorflow/c/eager/c_api_test_util.h"
#include "tensorflow/c/eager/c_api_unified_experimental_internal.h"
#include "tensorflow/c/eager/unified_api_testutil.h"
#include "tensorflow/c/experimental/gradients/grad_test_helper.h"
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/array_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace gradients {
namespace internal {
namespace {
using tensorflow::TF_StatusPtr;
Status IdentityNModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
std::vector<AbstractTensorHandle*> temp_outputs(2);
TF_RETURN_IF_ERROR(
ops::IdentityN(ctx, inputs, absl::MakeSpan(temp_outputs), "IdentityN"));
// Although `ops::IdentityN` returns 2 tensors, the first tensor isn't needed
// for computing the gradient, so we can safely drop it.
outputs[0] = temp_outputs[1];
temp_outputs[0]->Unref();
return Status::OK();
}
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
status_ = StatusFromTF_Status(status.get());
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
{
AbstractContext* ctx_raw = nullptr;
status_ =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
immediate_execution_ctx_.reset(ctx_raw);
}
// Computing numerical gradients with TensorFloat-32 is numerically
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
// low tolerances.
enable_tensor_float_32_execution(false);
}
AbstractContextPtr immediate_execution_ctx_;
GradientRegistry registry_;
Status status_;
public:
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
bool UseFunction() const { return std::get<2>(GetParam()); }
};
TEST_P(CppGradients, TestIdentityNGrad) {
// This test is interesting because the current implementation of GradientTape
// would return [0, 1], whereas we use build_default_zeros_grads=false here,
// so we get back [nullptr, 1].
AbstractTensorHandlePtr x1;
{
AbstractTensorHandle* x1_raw = nullptr;
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 1.0f, &x1_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x1.reset(x1_raw);
}
AbstractTensorHandlePtr x2;
{
AbstractTensorHandle* x2_raw = nullptr;
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 1.0f, &x2_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x2.reset(x2_raw);
}
status_ = registry_.Register("IdentityN", IdentityNRegisterer);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
auto IdentityNGradModel = BuildGradModel(IdentityNModel, registry_);
std::vector<AbstractTensorHandle*> outputs(2);
status_ =
RunModel(IdentityNGradModel, immediate_execution_ctx_.get(),
{x1.get(), x2.get()}, absl::MakeSpan(outputs), UseFunction());
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
EXPECT_EQ(outputs[0], nullptr);
ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], {1.0f}, /*dims*/ {},
/*abs_error*/ 0));
outputs[1]->Unref();
}
#ifdef PLATFORM_GOOGLE
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*use_function*/ ::testing::Values(true, false)));
#else
INSTANTIATE_TEST_SUITE_P(
UnifiedCAPI, CppGradients,
::testing::Combine(::testing::Values("graphdef", "mlir"),
/*tfrt*/ ::testing::Values(false),
/*use_function*/ ::testing::Values(true, false)));
#endif
} // namespace
} // namespace internal
} // namespace gradients
} // namespace tensorflow

View File

@ -101,7 +101,7 @@ TEST_P(CustomGradientTest, ExpWithPassThroughGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
Status s = TestScalarTensorHandle(ctx.get(), 1.0f, &x_raw);
Status s = TestScalarTensorHandle<float, TF_FLOAT>(ctx.get(), 1.0f, &x_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
x.reset(x_raw);
}

View File

@ -90,9 +90,9 @@ class CppGradients
{
AbstractContext* ctx_raw = nullptr;
Status s =
status_ =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
immediate_execution_ctx_.reset(ctx_raw);
}
@ -115,8 +115,8 @@ TEST_P(CppGradients, TestAddGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -124,12 +124,14 @@ TEST_P(CppGradients, TestAddGrad) {
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
y.reset(y_raw);
}
// TODO(srbs): Rename ops::Add to ops::AddV2 and AddRegister to
// AddV2Registerer.
status_ = registry_.Register("AddV2", AddRegisterer);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
@ -142,8 +144,8 @@ TEST_P(CppGradients, TestExpGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -167,8 +169,8 @@ TEST_P(CppGradients, TestMatMulGrad) {
AbstractTensorHandlePtr A;
{
AbstractTensorHandle* A_raw;
status_ = TestTensorHandleWithDimsFloat(immediate_execution_ctx_.get(),
A_vals, A_dims, 2, &A_raw);
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
A.reset(A_raw);
}
@ -178,8 +180,8 @@ TEST_P(CppGradients, TestMatMulGrad) {
AbstractTensorHandlePtr B;
{
AbstractTensorHandle* B_raw;
status_ = TestTensorHandleWithDimsFloat(immediate_execution_ctx_.get(),
B_vals, B_dims, 2, &B_raw);
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), B_vals, B_dims, 2, &B_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
B.reset(B_raw);
}
@ -204,12 +206,77 @@ TEST_P(CppGradients, TestMatMulGrad) {
}
}
TEST_P(CppGradients, TestMatMulGradManual) {
float A_vals[] = {1.0f, 2.0f, 3.0f, 4.0f, 5.0f, 6.0f, 7.0f, 8.0f, 9.0f};
int64_t A_dims[] = {3, 3};
AbstractTensorHandlePtr A;
{
AbstractTensorHandle* A_raw;
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
A.reset(A_raw);
}
float B_vals[] = {9.0f, 8.0f, 7.0f, 6.0f, 5.0f, 4.0f, 3.0f, 2.0f, 1.0f};
int64_t B_dims[] = {3, 3};
AbstractTensorHandlePtr B;
{
AbstractTensorHandle* B_raw;
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), B_vals, B_dims, 2, &B_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
B.reset(B_raw);
}
status_ = registry_.Register("MatMul", MatMulRegisterer);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
bool transpose_a_vals[] = {false, false, true, true};
bool transpose_b_vals[] = {false, true, false, true};
float dA_vals[4][9] = {{24, 15, 6, 24, 15, 6, 24, 15, 6},
{18, 15, 12, 18, 15, 12, 18, 15, 12},
{24, 24, 24, 15, 15, 15, 6, 6, 6},
{18, 18, 18, 15, 15, 15, 12, 12, 12}};
float dB_vals[4][9] = {{12, 12, 12, 15, 15, 15, 18, 18, 18},
{12, 15, 18, 12, 15, 18, 12, 15, 18},
{6, 6, 6, 15, 15, 15, 24, 24, 24},
{6, 15, 24, 6, 15, 24, 6, 15, 24}};
for (int i{}; i < 4; ++i) {
bool transpose_a = transpose_a_vals[i];
bool transpose_b = transpose_b_vals[i];
Model MatMulModel =
[transpose_a, transpose_b](
AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) -> Status {
return ops::MatMul(ctx, inputs, outputs, "MatMul", transpose_a,
transpose_b);
};
Model MatMulGradModel = BuildGradModel(MatMulModel, registry_);
std::vector<AbstractTensorHandle*> outputs(2);
status_ =
RunModel(MatMulGradModel, immediate_execution_ctx_.get(),
{A.get(), B.get()}, absl::MakeSpan(outputs), UseFunction());
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], dA_vals[i],
/*dims*/ {3, 3},
/*abs_error*/ 0));
ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[1], dB_vals[i],
/*dims*/ {3, 3},
/*abs_error*/ 0));
outputs[0]->Unref();
outputs[1]->Unref();
}
}
TEST_P(CppGradients, TestSqrtGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -226,8 +293,8 @@ TEST_P(CppGradients, TestNegGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -244,8 +311,8 @@ TEST_P(CppGradients, TestSubGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -253,8 +320,8 @@ TEST_P(CppGradients, TestSubGrad) {
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
y.reset(y_raw);
}
@ -271,8 +338,8 @@ TEST_P(CppGradients, TestMulGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -280,8 +347,8 @@ TEST_P(CppGradients, TestMulGrad) {
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
y.reset(y_raw);
}
@ -298,8 +365,8 @@ TEST_P(CppGradients, TestLog1pGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -321,8 +388,8 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
AbstractTensorHandlePtr x;
{
AbstractTensorHandle* x_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &x_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &x_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
x.reset(x_raw);
}
@ -330,8 +397,8 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
AbstractTensorHandlePtr y;
{
AbstractTensorHandle* y_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 2.0f, &y_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 2.0f, &y_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
y.reset(y_raw);
}
@ -344,8 +411,8 @@ TEST_P(CppGradients, TestDivNoNanGrad) {
AbstractTensorHandlePtr z;
{
AbstractTensorHandle* z_raw = nullptr;
status_ =
TestScalarTensorHandle(immediate_execution_ctx_.get(), 0.0f, &z_raw);
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 0.0f, &z_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
z.reset(z_raw);
}

View File

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/c/experimental/gradients/tape/tape_context.h"
#include "tensorflow/c/experimental/ops/nn_ops.h"
#include "tensorflow/c/tf_status_helper.h"
#include "tensorflow/core/platform/tensor_float_32_utils.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
@ -35,28 +36,6 @@ Status ReluModel(AbstractContext* ctx,
return ops::Relu(ctx, inputs, outputs, "Relu");
}
Status ReluGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("Relu", ReluRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]);
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::Relu(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "ReluGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status SparseSoftmaxCrossEntropyWithLogitsModel(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
@ -73,80 +52,38 @@ Status SparseSoftmaxCrossEntropyWithLogitsModel(
return Status::OK();
}
Status SparseSoftmaxCrossEntropyWithLogitsGradModel(
AbstractContext* ctx, absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(
registry.Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]); // Watch score.
tape.Watch(inputs[1]); // Watch label.
std::vector<AbstractTensorHandle*> temp_outputs(2);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::SparseSoftmaxCrossEntropyWithLogits(
tape_ctx.get(), inputs, absl::MakeSpan(temp_outputs),
"SparseSoftmaxCrossEntropyWithLogitsGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
Status BiasAddModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
return ops::BiasAdd(ctx, inputs, outputs, "BiasAdd");
}
Status BiasAddGradModel(AbstractContext* ctx,
absl::Span<AbstractTensorHandle* const> inputs,
absl::Span<AbstractTensorHandle*> outputs) {
GradientRegistry registry;
TF_RETURN_IF_ERROR(registry.Register("BiasAdd", BiasAddRegisterer));
Tape tape(/*persistent=*/false);
tape.Watch(inputs[0]); // Watch A.
tape.Watch(inputs[1]); // Watch Bias.
std::vector<AbstractTensorHandle*> temp_outputs(1);
AbstractContextPtr tape_ctx(new TapeContext(ctx, &tape, registry));
TF_RETURN_IF_ERROR(ops::BiasAdd(tape_ctx.get(), inputs,
absl::MakeSpan(temp_outputs), "BiasAddGrad"));
TF_RETURN_IF_ERROR(tape.ComputeGradient(ctx, /*targets=*/temp_outputs,
/*sources=*/inputs,
/*output_gradients=*/{}, outputs));
for (auto temp_output : temp_outputs) {
temp_output->Unref();
}
return Status::OK();
}
class CppGradients
: public ::testing::TestWithParam<std::tuple<const char*, bool, bool>> {
protected:
void SetUp() override {
TF_StatusPtr status(TF_NewStatus());
TF_SetTracingImplementation(std::get<0>(GetParam()), status.get());
Status s = StatusFromTF_Status(status.get());
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = StatusFromTF_Status(status.get());
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
{
AbstractContext* ctx_raw = nullptr;
Status s =
status_ =
BuildImmediateExecutionContext(std::get<1>(GetParam()), &ctx_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
ctx_.reset(ctx_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
immediate_execution_ctx_.reset(ctx_raw);
}
// Computing numerical gradients with TensorFloat-32 is numerically
// unstable. Some forward pass tests also fail with TensorFloat-32 due to
// low tolerances.
enable_tensor_float_32_execution(false);
}
AbstractContextPtr ctx_;
AbstractContextPtr immediate_execution_ctx_;
GradientRegistry registry_;
Status status_;
public:
bool UseMlir() const { return strcmp(std::get<0>(GetParam()), "mlir") == 0; }
@ -154,34 +91,41 @@ class CppGradients
};
TEST_P(CppGradients, TestReluGrad) {
status_ = registry_.Register("Relu", ReluRegisterer);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
auto ReluGradModel = BuildGradModel(ReluModel, registry_);
float X_vals[] = {1.0f, 2.0f, 3.0f, -5.0f, -4.0f, -3.0f, 2.0f, 10.0f, -1.0f};
int64_t X_dims[] = {3, 3};
AbstractTensorHandlePtr X;
{
AbstractTensorHandle* X_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
X.reset(X_raw);
}
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
ReluModel, ReluGradModel, ctx_.get(), {X.get()}, UseFunction()));
ReluModel, ReluGradModel, immediate_execution_ctx_.get(), {X.get()},
UseFunction()));
// Mathematically, Relu isn't differentiable at `0`. So `gradient_checker`
// does not work with it.
AbstractTensorHandlePtr Y;
{
AbstractTensorHandle* Y_raw;
Status s = TestScalarTensorHandle(ctx_.get(), 0.0f, &Y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = TestScalarTensorHandle<float, TF_FLOAT>(
immediate_execution_ctx_.get(), 0.0f, &Y_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
Y.reset(Y_raw);
}
std::vector<AbstractTensorHandle*> outputs(1);
auto s = RunModel(ReluGradModel, ctx_.get(), {Y.get()},
absl::MakeSpan(outputs), UseFunction());
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = RunModel(ReluGradModel, immediate_execution_ctx_.get(), {Y.get()},
absl::MakeSpan(outputs), UseFunction());
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
ASSERT_NO_FATAL_FAILURE(CheckTensorValue(outputs[0], {0.0f}, /*dims*/ {},
/*abs_error*/ 0));
outputs[0]->Unref();
@ -200,27 +144,31 @@ TEST_P(CppGradients, TestSparseSoftmaxCrossEntropyWithLogitsGrad) {
AbstractTensorHandlePtr X;
{
AbstractTensorHandle* X_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), X_vals, X_dims, 2, &X_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), X_vals, X_dims, 2, &X_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
X.reset(X_raw);
}
// Label
int Y_vals[] = {1, 0, 1};
int32_t Y_vals[] = {1, 0, 1};
int64_t Y_dims[] = {3};
AbstractTensorHandlePtr Y;
{
AbstractTensorHandle* Y_raw;
Status s =
TestTensorHandleWithDimsInt(ctx_.get(), Y_vals, Y_dims, 1, &Y_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = TestTensorHandleWithDims<int32_t, TF_INT32>(
immediate_execution_ctx_.get(), Y_vals, Y_dims, 1, &Y_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
Y.reset(Y_raw);
}
status_ = registry_.Register("SparseSoftmaxCrossEntropyWithLogits",
SparseSoftmaxCrossEntropyWithLogitsRegisterer);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
SparseSoftmaxCrossEntropyWithLogitsModel,
SparseSoftmaxCrossEntropyWithLogitsGradModel, ctx_.get(),
{X.get(), Y.get()}, UseFunction()));
BuildGradModel(SparseSoftmaxCrossEntropyWithLogitsModel, registry_),
immediate_execution_ctx_.get(), {X.get(), Y.get()}, UseFunction()));
}
TEST_P(CppGradients, TestBiasAddGrad) {
@ -234,9 +182,9 @@ TEST_P(CppGradients, TestBiasAddGrad) {
AbstractTensorHandlePtr A;
{
AbstractTensorHandle* A_raw;
Status s =
TestTensorHandleWithDimsFloat(ctx_.get(), A_vals, A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), A_vals, A_dims, 2, &A_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
A.reset(A_raw);
}
// Bias
@ -245,15 +193,18 @@ TEST_P(CppGradients, TestBiasAddGrad) {
AbstractTensorHandlePtr Bias;
{
AbstractTensorHandle* Bias_raw;
Status s = TestTensorHandleWithDimsFloat(ctx_.get(), Bias_vals, Bias_dims,
1, &Bias_raw);
ASSERT_EQ(errors::OK, s.code()) << s.error_message();
status_ = TestTensorHandleWithDims<float, TF_FLOAT>(
immediate_execution_ctx_.get(), Bias_vals, Bias_dims, 1, &Bias_raw);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
Bias.reset(Bias_raw);
}
status_ = registry_.Register("BiasAdd", BiasAddRegisterer);
ASSERT_EQ(errors::OK, status_.code()) << status_.error_message();
ASSERT_NO_FATAL_FAILURE(CompareNumericalAndAutodiffGradients(
BiasAddModel, BiasAddGradModel, ctx_.get(), {A.get(), Bias.get()},
UseFunction()));
BiasAddModel, BuildGradModel(BiasAddModel, registry_),
immediate_execution_ctx_.get(), {A.get(), Bias.get()}, UseFunction()));
}
#ifdef PLATFORM_GOOGLE

View File

@ -141,6 +141,72 @@ AffineMap GetTransposeMapForReduction(MLIRContext* context, int rank,
return inversePermutation(map);
}
/// Returns true if the given `dimension_numbers` from an mhlo.convolution op
/// follows a canonical form:
///
/// * Input dimensions have order: (batch_count, spatial_dims,
/// input_channel_count).
/// * Filter dimensions have order: (spatial_dims, input_channel_count,
/// output_channel_count).
/// * Output dimensions have order: (batch_count, spatial_dims,
/// output_channel_count).
template <typename DimensionNumbersTy>
static bool HasCanonicalDimensionNumbers(
const DimensionNumbersTy& dimension_numbers) {
const int input_spatial_rank =
llvm::size(dimension_numbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
dimension_numbers.input_feature_dimension().getInt() !=
(input_spatial_rank + 1)) {
return false;
}
const int kernel_spatial_rank =
llvm::size(dimension_numbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
kernel_spatial_rank ||
dimension_numbers.kernel_output_feature_dimension().getInt() !=
(kernel_spatial_rank + 1)) {
return false;
}
const int output_spatial_rank =
llvm::size(dimension_numbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims..., output_feature_count.
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
dimension_numbers.output_feature_dimension().getInt() !=
(output_spatial_rank + 1)) {
return false;
}
if (input_spatial_rank != output_spatial_rank ||
input_spatial_rank != kernel_spatial_rank) {
return false;
}
auto input_spatial_dim = dimension_numbers.input_spatial_dimensions().begin();
auto kernel_spatial_dim =
dimension_numbers.kernel_spatial_dimensions().begin();
auto output_spatial_dim =
dimension_numbers.output_spatial_dimensions().begin();
// Check spatial dims are ordered correctly.
for (int i = 0; i < input_spatial_rank; ++i) {
const int dim = i + 1;
if ((*input_spatial_dim++).getZExtValue() != dim ||
(*output_spatial_dim++).getZExtValue() != dim ||
(*kernel_spatial_dim++).getZExtValue() != i) {
return false;
}
}
return true;
}
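// Editorial illustration (hedged): for the 2-D case this check accepts exactly
// the NHWC/HWCF layout used by the conversion tests added later in this change:
//   input:  batch = 0, spatial = [1, 2], feature = 3
//   kernel: spatial = [0, 1], input_feature = 2, output_feature = 3
//   output: batch = 0, spatial = [1, 2], feature = 3
// Any other ordering makes HasCanonicalDimensionNumbers return false.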
template <typename OpTy, bool isLHLO = true>
class PointwiseToLinalgConverter : public OpConversionPattern<OpTy> {
public:
@ -264,61 +330,10 @@ struct ConvToLinalgConverter : public OpConversionPattern<lmhlo::ConvOp> {
public:
using OpConversionPattern<lmhlo::ConvOp>::OpConversionPattern;
// This code has been adapted from IREE's
// (https://github.com/google/iree/) mhlo -> linalg conversion.
LogicalResult matchAndRewrite(
lmhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const final {
// Check validity of dimension information.
if (const mhlo::ConvDimensionNumbers& dimension_numbers =
op.dimension_numbers()) {
const int input_spatial_rank =
llvm::size(dimension_numbers.input_spatial_dimensions());
// The dimensions for input should follow the order of
// batch_count, spatial_dims..., input_feature_count.
if (dimension_numbers.input_batch_dimension().getInt() != 0 ||
dimension_numbers.input_feature_dimension().getInt() !=
(input_spatial_rank + 1))
return failure();
const int kernel_spatial_rank =
llvm::size(dimension_numbers.kernel_spatial_dimensions());
// The dimensions for filter should follow the order of
// spatial_dims..., input_feature_count, num_output_feature_count.
if (dimension_numbers.kernel_input_feature_dimension().getInt() !=
kernel_spatial_rank ||
dimension_numbers.kernel_output_feature_dimension().getInt() !=
(kernel_spatial_rank + 1))
return failure();
const int output_spatial_rank =
llvm::size(dimension_numbers.output_spatial_dimensions());
// The dimensions for output should follow the order of
// batch_count, spatial_dims.., output_feature_count.
if (dimension_numbers.output_batch_dimension().getInt() != 0 ||
dimension_numbers.output_feature_dimension().getInt() !=
(output_spatial_rank + 1))
return failure();
if (input_spatial_rank != output_spatial_rank ||
input_spatial_rank != kernel_spatial_rank)
return failure();
auto input_spatial_dim =
dimension_numbers.input_spatial_dimensions().begin();
auto kernel_spatial_dim =
dimension_numbers.kernel_spatial_dimensions().begin();
auto output_spatial_dim =
dimension_numbers.output_spatial_dimensions().begin();
// Check if spatial dims are ordered correctly.
for (int i = 0; i < input_spatial_rank; ++i) {
const int dim = i + 1;
if ((*input_spatial_dim++).getZExtValue() != dim ||
(*output_spatial_dim++).getZExtValue() != dim ||
(*kernel_spatial_dim++).getZExtValue() != i)
return failure();
}
}
if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
// TODO: LHS dilation for deconvolution is not supported yet.
// TODO(jurahul): Window reversal is not supported yet.
@ -1432,6 +1447,80 @@ struct PadOpOnTensorsConversion : public OpConversionPattern<mhlo::PadOp> {
}
};
/// Converts mhlo.conv operation to linalg named op. This only covers normal
/// convolution cases. The op must have canonical dimension numbers. Depthwise
/// convolution and pointwise convolution are not handled in the conversion.
struct NormalConvOpOnTensorsConversion
: public OpConversionPattern<mhlo::ConvOp> {
using OpConversionPattern<mhlo::ConvOp>::OpConversionPattern;
LogicalResult matchAndRewrite(
mhlo::ConvOp op, ArrayRef<Value> args,
ConversionPatternRewriter& rewriter) const override {
if (!HasCanonicalDimensionNumbers(op.dimension_numbers())) return failure();
if (op.feature_group_count() != 1u) return failure();
mhlo::ConvOp::Adaptor adaptor(args);
Location loc = op.getLoc();
Value input = adaptor.lhs();
Value filter = adaptor.rhs();
auto result_type = op.getResult().getType().cast<ShapedType>();
int64_t rank = result_type.getRank();
// Check if padding is zero.
DenseIntElementsAttr padding = op.paddingAttr();
if (padding &&
(!padding.isSplat() || padding.getSplatValue<int64_t>() != 0)) {
return rewriter.notifyMatchFailure(op, "expected no padding");
}
// The output shape is (N, spatial_dims..., F).
SmallVector<Value, 8> dyn_sizes;
for (int64_t i = 0, e = rank - 1; i < e; ++i) {
if (!result_type.isDynamicDim(i)) continue;
dyn_sizes.push_back(rewriter.create<DimOp>(loc, input, i));
}
if (result_type.isDynamicDim(rank - 1)) {
dyn_sizes.push_back(rewriter.create<DimOp>(loc, filter, rank - 1));
}
Value init_tensor = rewriter.create<linalg::InitTensorOp>(
loc, dyn_sizes, result_type.getShape(), result_type.getElementType());
auto zero_attr = rewriter.getZeroAttr(result_type.getElementType());
Value zero = rewriter.create<ConstantOp>(loc, zero_attr);
Value zero_tensor =
rewriter.create<linalg::FillOp>(loc, init_tensor, zero).getResult(0);
linalg::LinalgOp res;
Attribute strides = op.window_stridesAttr();
// TODO(ataei): Only dilated kernels are supported right now. We need to
// consider input dilation for the deconvolution cases.
Attribute dilations = op.rhs_dilationAttr();
switch (rank) {
case 3: {
res = rewriter.create<linalg::ConvInputNWCFilterWCFOp>(
loc, result_type, ValueRange{input, filter},
ValueRange{zero_tensor}, dilations, strides);
break;
}
case 4: {
res = rewriter.create<linalg::ConvInputNHWCFilterHWCFOp>(
loc, result_type, ValueRange{input, filter},
ValueRange{zero_tensor}, dilations, strides);
break;
}
case 5: {
res = rewriter.create<linalg::ConvInputNDHWCFilterDHWCFOp>(
loc, result_type, ValueRange{input, filter},
ValueRange{zero_tensor}, dilations, strides);
break;
}
default:
return rewriter.notifyMatchFailure(op, "expected 1/2/3D conv op");
}
rewriter.replaceOp(op, res.getOperation()->getResults());
return success();
}
};
void populateLHLOToLinalgConversionPattern(MLIRContext* context,
OwningRewritePatternList* patterns) {
// clang-format off
@ -1656,6 +1745,7 @@ void populateHLOToLinalgConversionPattern(MLIRContext* context,
linalg::BatchMatmulI32I32I32Op>,
DotGeneralOpOnTensorsConversion<FloatType, 32, FloatType, 32,
linalg::BatchMatmulOp>,
NormalConvOpOnTensorsConversion,
ReduceOnTensorsConversion,
PadOpOnTensorsConversion>(context);
// clang-format on

View File

@ -1441,3 +1441,171 @@ func @pad_tensor(%arg0: tensor<12x4xf32>, %arg1: tensor<f32>) -> tensor<18x12xf3
// CHECK: linalg.pad_tensor %[[ARG0]] low[%[[C4]], %[[C5]]] high[%[[C2]], %[[C3]]]
// CHECK: linalg.yield %[[PAD]] : f32
// CHECK: } : tensor<12x4xf32> to tensor<18x12xf32>
// -----
func @linalg.conv_1d_input_nwc_filter_wcf(%arg0: tensor<?x?x?xf32>, %arg1: tensor<?x?x?xf32>)
-> tensor<?x?x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 2 : i64,
input_spatial_dimensions = dense<[1]> : tensor<1xi64>,
kernel_input_feature_dimension = 1 : i64,
kernel_output_feature_dimension = 2 : i64,
kernel_spatial_dimensions = dense<[0]> : tensor<1xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 2 : i64,
output_spatial_dimensions = dense<[1]> : tensor<1xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0], [0]]> : tensor<2x1xi64>,
rhs_dilation = dense<1> : tensor<1xi64>,
window_strides = dense<1> : tensor<1xi64>
} : (tensor<?x?x?xf32>, tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
return %0 : tensor<?x?x?xf32>
}
// CHECK-LABEL: func @linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG1]], %[[C2]] : tensor<?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_1d_input_nwc_filter_wcf
// CHECK-SAME: {dilations = dense<1> : tensor<1xi64>
// CHECK-SAME: strides = dense<1> : tensor<1xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?xf32>, tensor<?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?xf32>) -> tensor<?x?x?xf32>
// -----
func @conv_2d_input_nhwc_filter_hwcf(%arg0: tensor<?x?x?x?xf32>, %arg1: tensor<?x?x?x?xf32>)
-> tensor<?x?x?x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0, 0], [0, 0]]> : tensor<2x2xi64>,
rhs_dilation = dense<1> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
} : (tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
return %0 : tensor<?x?x?x?xf32>
}
// CHECK-LABEL: func @conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG1]], %[[C3]] : tensor<?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<1> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?xf32>, tensor<?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?xf32>) -> tensor<?x?x?x?xf32>
// -----
func @conv_3d_input_ndhwc_filter_dhwcf(%arg0: tensor<?x?x?x?x?xf32>, %arg1: tensor<?x?x?x?x?xf32>)
-> tensor<?x?x?x?x?xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 4 : i64,
input_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>,
kernel_input_feature_dimension = 3 : i64,
kernel_output_feature_dimension = 4 : i64,
kernel_spatial_dimensions = dense<[0, 1, 2]> : tensor<3xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 4 : i64,
output_spatial_dimensions = dense<[1, 2, 3]> : tensor<3xi64>
},
feature_group_count = 1 : i64,
padding = dense<[[0, 0, 0], [0, 0, 0]]> : tensor<2x3xi64>,
rhs_dilation = dense<1> : tensor<3xi64>,
window_strides = dense<1> : tensor<3xi64>
} : (tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
return %0 : tensor<?x?x?x?x?xf32>
}
// CHECK-LABEL: func @conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[C0:.+]] = constant 0 : index
// CHECK: %[[DIM0:.+]] = dim %[[ARG0]], %[[C0]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C1:.+]] = constant 1 : index
// CHECK: %[[DIM1:.+]] = dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C2:.+]] = constant 2 : index
// CHECK: %[[DIM2:.+]] = dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C3:.+]] = constant 3 : index
// CHECK: %[[DIM3:.+]] = dim %[[ARG0]], %[[C3]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[C4:.+]] = constant 4 : index
// CHECK: %[[DIM4:.+]] = dim %[[ARG1]], %[[C4]] : tensor<?x?x?x?x?xf32>
// CHECK: %[[INIT:.+]] = linalg.init_tensor [%[[DIM0]], %[[DIM1]], %[[DIM2]], %[[DIM3]], %[[DIM4]]]
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]])
// CHECK: linalg.conv_3d_input_ndhwc_filter_dhwcf
// CHECK-SAME: {dilations = dense<1> : tensor<3xi64>
// CHECK-SAME: strides = dense<1> : tensor<3xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<?x?x?x?x?xf32>, tensor<?x?x?x?x?xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<?x?x?x?x?xf32>) -> tensor<?x?x?x?x?xf32>
// -----
func @conv2d_1452x2223_dilated_valid(%arg0: tensor<1x4x5x2xf32>, %arg1: tensor<2x2x2x3xf32>)
-> tensor<1x2x4x3xf32> {
%0 = "mhlo.convolution"(%arg0, %arg1) {
batch_group_count = 1 : i64,
dimension_numbers = {
input_batch_dimension = 0 : i64,
input_feature_dimension = 3 : i64,
input_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>,
kernel_input_feature_dimension = 2 : i64,
kernel_output_feature_dimension = 3 : i64,
kernel_spatial_dimensions = dense<[0, 1]> : tensor<2xi64>,
output_batch_dimension = 0 : i64,
output_feature_dimension = 3 : i64,
output_spatial_dimensions = dense<[1, 2]> : tensor<2xi64>
},
feature_group_count = 1 : i64,
padding = dense<0> : tensor<2x2xi64>,
rhs_dilation = dense<[2, 1]> : tensor<2xi64>,
window_strides = dense<1> : tensor<2xi64>
} : (tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>) -> tensor<1x2x4x3xf32>
return %0 : tensor<1x2x4x3xf32>
}
// CHECK-LABEL: func @conv2d_1452x2223_dilated_valid
// CHECK-SAME: %[[ARG0:[a-zA-Z0-9_]*]]
// CHECK-SAME: %[[ARG1:[a-zA-Z0-9_]*]]
// CHECK: %[[INIT:.+]] = linalg.init_tensor [1, 2, 4, 3] : tensor<1x2x4x3xf32>
// CHECK: %[[ZERO:.+]] = constant 0.000000e+00 : f32
// CHECK: %[[FILL:.+]] = linalg.fill(%[[INIT]], %[[ZERO]]) : tensor<1x2x4x3xf32>, f32 -> tensor<1x2x4x3xf32>
// CHECK: linalg.conv_2d_input_nhwc_filter_hwcf
// CHECK-SAME: {dilations = dense<[2, 1]> : tensor<2xi64>
// CHECK-SAME: strides = dense<1> : tensor<2xi64>}
// CHECK-SAME: ins(%[[ARG0]], %[[ARG1]] : tensor<1x4x5x2xf32>, tensor<2x2x2x3xf32>)
// CHECK-SAME: outs(%[[FILL]] : tensor<1x2x4x3xf32>) -> tensor<1x2x4x3xf32>

View File

@ -2210,6 +2210,70 @@ static LogicalResult Verify(UnidirectionalSequenceLSTMOp op) {
"UnidirectionalSequenceLSTMOp expected to have two stateful operands");
}
LogicalResult UnidirectionalSequenceLSTMOp::inferReturnTypes(
MLIRContext *, Optional<Location>, ValueRange operands, DictionaryAttr attr,
RegionRange, SmallVectorImpl<Type> &inferredReturnTypes) {
Value input = operands[0];
auto input_type = input.getType().dyn_cast_or_null<RankedTensorType>();
Value output_state = operands[18];
auto output_state_type =
output_state.getType().dyn_cast_or_null<RankedTensorType>();
if (input_type && input_type.hasRank() && input_type.getRank() != 3) {
return failure();
}
if (output_state_type && output_state_type.hasRank() &&
output_state_type.getRank() != 2) {
return failure();
}
if (!input_type || !input_type.hasRank() || !output_state_type ||
!output_state_type.hasRank()) {
// We cannot infer the output shape since we don't know the input shape or
// the output state shape. We will set the output shape as unranked.
Type result_type;
result_type = UnrankedTensorType::get(
input.getType().cast<ShapedType>().getElementType());
inferredReturnTypes.assign({result_type});
return success();
}
// Default to non-time_major.
bool time_majored = attr.getNamed("time_major").hasValue()
? attr.getNamed("time_major")
.getValue()
.second.cast<BoolAttr>()
.getValue()
: false;
int batch =
time_majored ? input_type.getDimSize(1) : input_type.getDimSize(0);
int time = time_majored ? input_type.getDimSize(0) : input_type.getDimSize(1);
int n_output = output_state_type.getDimSize(1);
// Build the output shape.
SmallVector<int64_t, 3> output_shape;
if (time_majored) {
output_shape = {time, batch, n_output};
} else {
output_shape = {batch, time, n_output};
}
auto result_type =
mlir::RankedTensorType::get(output_shape, input_type.getElementType());
inferredReturnTypes.assign({result_type});
return success();
}
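// Editorial worked example (hedged; mirrors the shape-inference test added in
// this change): with time_major = false, an input of type tensor<600x10x20xf32>
// and an output_state of type tensor<600x40xf32> give batch = 600, time = 10,
// n_output = 40, so the inferred result type is tensor<600x10x40xf32>.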
bool UnidirectionalSequenceLSTMOp::isCompatibleReturnTypes(ArrayRef<Type> lhs,
ArrayRef<Type> rhs) {
if (lhs.size() != rhs.size() || lhs.size() != 1) return false;
if (failed(mlir::verifyCompatibleShape(lhs[0], rhs[0]))) return false;
return true;
}
//===----------------------------------------------------------------------===//
// BidirectionalSequenceLSTMOp
//===----------------------------------------------------------------------===//

View File

@ -3959,7 +3959,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
TFL_OperandHasRank<15, 1>, // output_gate_bias
TFL_OperandIsNoneOrHasRank<16, 2>, // projection_weights
TFL_OperandIsNoneOrHasRank<17, 1>, // projection_bias
TFL_StatefulOp]> {
TFL_StatefulOp,
DeclareOpInterfaceMethods<InferTypeOpInterface>
]> {
let summary = "Unidirectional sequence lstm operator";
let description = [{
@ -4043,6 +4045,9 @@ def TFL_UnidirectionalSequenceLSTMOp :
let extraClassDeclaration = [{
// StatefulOpInterface:
std::vector<int> GetStatefulOperands() { return {18, 19}; }
// Compatible return types check
static bool isCompatibleReturnTypes(ArrayRef<Type> l, ArrayRef<Type> r);
}];
}

View File

@ -459,3 +459,40 @@ func @tensorlistConcat(%arg0: tensor<?xf32>, %element_shape: tensor<0xi32>, %lea
// CHECK: %tensor, %lengths = "tf.TensorListConcatV2"(%0, %arg1, %arg2) : (tensor<!tf.variant<tensor<f32>>>, tensor<0xi32>, tensor<i64>) -> (tensor<?xf32>, tensor<0xi64>)
// CHECK: return %tensor, %lengths : tensor<?xf32>, tensor<0xi64>
}
// -----
func @whileLoopWithDynamicTensorList(%arg0: tensor<i32>, %arg1: tensor<i32>) -> tensor<*xf32> {
%cst = constant dense<3> : tensor<1xi32>
%cst_0 = constant dense<0> : tensor<i32>
%cst_1 = constant dense<-1> : tensor<i32>
%0 = "tf.TensorListReserve"(%arg0, %arg1) : (tensor<i32>, tensor<i32>) -> tensor<!tf.variant<tensor<?x?xf32>>>
%1:2 = "tf.While"(%cst_0, %0) {T = ["tfdtype$DT_INT32", "tfdtype$DT_VARIANT"], body = @tensorlistWhileBody, cond = @tensorlistWhileCond, is_stateless = false} : (tensor<i32>, tensor<!tf.variant<tensor<?x?xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
%2 = "tf.TensorListStack"(%1#1, %cst_1) : (tensor<!tf.variant<tensor<*xf32>>>, tensor<i32>) -> tensor<*xf32>
return %2 : tensor<*xf32>
// verify tensorlist ops pass through.
// CHECK-LABEL: func @whileLoopWithDynamicTensorList
// CHECK: "tf.TensorListReserve"
// CHECK: "tf.While"
// CHECK-SAME: (tensor<i32>, tensor<!tf.variant<tensor<?x?xf32>>>) -> (tensor<i32>, tensor<!tf.variant<tensor<*xf32>>>)
// CHECK: "tf.TensorListStack"
}
func @tensorlistWhileBody(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>) {
%0 = "tf.TensorListLength"(%arg1) : (tensor<!tf.variant>) -> tensor<i32>
%1 = "tf.Identity"(%arg1) : (tensor<!tf.variant>) -> tensor<!tf.variant>
return %0, %1 : tensor<i32>, tensor<!tf.variant>
// verify `body` function's signature stays unchanged.
// CHECK: func @tensorlistWhileBody(%[[ARG0:.*]]: tensor<i32>, %[[ARG:.*]]: tensor<!tf.variant>) -> (tensor<i32>, tensor<!tf.variant>)
}
func @tensorlistWhileCond(%arg0: tensor<i32>, %arg1: tensor<!tf.variant>) -> tensor<i1> {
%cst = constant dense<2> : tensor<i32>
%0 = "tf.Less"(%arg0, %cst) : (tensor<i32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
// verify `cond` function's signature stays unchanged.
// CHECK: func @tensorlistWhileCond(%[[ARG0:.*]]: tensor<i32>, %[[ARG1:.*]]: tensor<!tf.variant>) -> tensor<i1>
}

View File

@ -74,3 +74,14 @@ func @testConv2dShapeInvalidRanks(%arg0: tensor<1x112x80xf32>, %arg1: tensor<128
}
}
// -----
module attributes {tf.versions = {producer = 888 : i32}} {
// CHECK-LABEL: testUnidirectionalSequenceLstmShapeInference
func @testUnidirectionalSequenceLstmShapeInference(%arg0: tensor<600 x 10 x 20 x f32>, %arg1: tensor<? x ? x f32>, %arg2: tensor<? x ? x f32>, %arg3: tensor<? x ? x f32>, %arg4: tensor<? x ? x f32>, %arg5: tensor<? x ? x f32>, %arg6: tensor<? x ? x f32>, %arg7: tensor<? x ? x f32>, %arg8: tensor<? x ? x f32>, %arg9: tensor<? x f32>, %arg10: tensor<? x f32>, %arg11: tensor<? x f32>, %arg12: tensor<? x f32>, %arg13: tensor<? x f32>, %arg14: tensor<? x f32>, %arg15: tensor<40 x f32>, %arg16: tensor<? x ? x f32>, %arg17: tensor<? x f32>, %arg18: tensor<600 x 40 x f32>, %arg19: tensor<600 x 40 x f32>, %arg20: tensor<? x f32>, %arg21: tensor<? x f32>, %arg22: tensor<? x f32>, %arg23: tensor<? x f32>) -> tensor<? x ? x ? x f32> {
// CHECK: "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600x10x20xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<40xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<600x40xf32>, tensor<600x40xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<600x10x40xf32
%0 = "tfl.unidirectional_sequence_lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg18, %arg19, %arg20, %arg21, %arg22, %arg23) {fused_activation_function = "NONE", time_major = false} : (tensor<600 x 10 x 20 x f32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<40xf32>, tensor<?x?xf32>, tensor<?xf32>, tensor<600x40xf32>, tensor<600x40xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>, tensor<?xf32>) -> tensor<? x ? x ? xf32>
return %0 : tensor<? x ? x ? x f32>
}
}

View File

@ -860,7 +860,8 @@ llvm::SmallSet<int, 4> GetTensorListArgumentsFromWhileOp(TF::WhileOp op) {
// Changes the function type of `cond_func` and `body_func` for the given While
// op.
LogicalResult UpdateFunctionTypes(TF::WhileOp op,
LogicalResult UpdateFunctionTypes(ConversionPatternRewriter &rewriter,
TF::WhileOp op,
llvm::SmallSet<int, 4> tensor_list_args) {
int func_index = 0;
for (FuncOp func : {op.cond_function(), op.body_function()}) {
@ -906,13 +907,18 @@ LogicalResult UpdateFunctionTypes(TF::WhileOp op,
// Change `func`'s argument types to `unranked_argument_types`. If its
// return types contain a `DT_VARIANT`, change them to the unranked type
// derived from the corresponding argument.
func.setType(FunctionType::get(op.getContext(), updated_argument_types,
updated_result_types));
// Change the argument type for the first block.
llvm::for_each(func.getArguments(), [&](BlockArgument &arg) {
arg.setType(updated_argument_types[arg.getArgNumber()]);
rewriter.updateRootInPlace(func, [&] {
func.setType(FunctionType::get(op.getContext(), updated_argument_types,
updated_result_types));
});
Region &entry = func.getRegion();
TypeConverter::SignatureConversion signature_conversion(
entry.getNumArguments());
for (auto arg : entry.getArguments()) {
signature_conversion.addInputs(
arg.getArgNumber(), updated_argument_types[arg.getArgNumber()]);
}
rewriter.applySignatureConversion(&entry, signature_conversion);
}
return success();
}
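// Editorial rationale (hedged, not in the original source): routing these
// in-place type updates through the ConversionPatternRewriter
// (updateRootInPlace and applySignatureConversion) keeps the dialect-conversion
// driver aware of the mutations, so they are tracked and can be rolled back if
// the enclosing pattern application fails.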
@ -944,7 +950,7 @@ struct ConvertWhile : public OpConversionPattern<TF::WhileOp> {
auto converted = rewriter.create<TF::WhileOp>(op.getLoc(), result_types,
operands, op.getAttrs());
converted.removeAttr("T");
(void)UpdateFunctionTypes(converted, tensor_list_args);
(void)UpdateFunctionTypes(rewriter, converted, tensor_list_args);
rewriter.replaceOp(op, converted.getResults());
return success();

View File

@ -162,18 +162,21 @@ std::string ExperimentalConvertSavedModelToMlir(
}
std::string ExperimentalConvertSavedModelV1ToMlirLite(
const std::string &saved_model_path, const std::string &tags,
bool upgrade_legacy, bool show_debug_info, TF_Status *status) {
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
TF_Status *status) {
std::unordered_set<string> tag_set =
absl::StrSplit(tags, ',', absl::SkipEmpty());
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
tensorflow::MLIRImportOptions import_options;
import_options.upgrade_legacy = upgrade_legacy;
auto module_or = SavedModelSignatureDefsToMlirImportLite(
saved_model_path, tag_set, /*exported_names=*/{}, &context,
import_options);
saved_model_path, tag_set, absl::Span<std::string>(exported_names),
&context, import_options);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";
@ -183,9 +186,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
}
std::string ExperimentalConvertSavedModelV1ToMlir(
const std::string &saved_model_path, const std::string &tags,
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
TF_Status *status) {
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool lift_variables, bool upgrade_legacy,
bool show_debug_info, TF_Status *status) {
// Load the saved model into a SavedModelBundle.
std::unordered_set<string> tag_set =
@ -200,12 +203,14 @@ std::string ExperimentalConvertSavedModelV1ToMlir(
}
// Convert the SavedModelBundle to an MLIR module.
std::vector<string> exported_names =
absl::StrSplit(exported_names_str, ',', absl::SkipEmpty());
mlir::MLIRContext context;
tensorflow::MLIRImportOptions import_options;
import_options.upgrade_legacy = upgrade_legacy;
auto module_or =
ConvertSavedModelV1ToMlir(bundle, {}, &context, import_options);
ConvertSavedModelV1ToMlir(bundle, absl::Span<std::string>(exported_names),
&context, import_options);
if (!module_or.status().ok()) {
Set_TF_Status_from_Status(status, module_or.status());
return "// error";

View File

@ -69,8 +69,9 @@ std::string ExperimentalConvertSavedModelToMlir(
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelV1ToMlirLite(
const std::string &saved_model_path, const std::string &tags,
bool upgrade_legacy, bool show_debug_info, TF_Status *status);
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool upgrade_legacy, bool show_debug_info,
TF_Status *status);
// Load a SavedModel V1 and return a textual MLIR string corresponding to it.
//
@ -83,9 +84,9 @@ std::string ExperimentalConvertSavedModelV1ToMlirLite(
// Returns:
// A string of textual MLIR representing the raw imported SavedModel.
std::string ExperimentalConvertSavedModelV1ToMlir(
const std::string &saved_model_path, const std::string &tags,
bool lift_variables, bool upgrade_legacy, bool show_debug_info,
TF_Status *status);
const std::string &saved_model_path, const std::string &exported_names_str,
const std::string &tags, bool lift_variables, bool upgrade_legacy,
bool show_debug_info, TF_Status *status);
std::string ExperimentalRunPassPipeline(const std::string &mlir_txt,
const std::string &pass_pipeline,

View File

@ -102,11 +102,13 @@ def do_test(create_signature,
logging.info('Saved model to: %s', save_model_path)
# TODO(b/153507667): Set the following boolean flag once the hoisting
# variables logic from SavedModel importer is removed.
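  # An empty exported_names string preserves the previous behavior of passing
  # no exported names to the importer.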
exported_names = ''
lift_variables = False
upgrade_legacy = True
if use_lite:
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir_lite(
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
save_model_path, exported_names,
','.join([tf.saved_model.tag_constants.SERVING]),
upgrade_legacy, show_debug_info)
# We don't strictly need this, but it serves as a handy sanity check
# for that API, which is otherwise a bit annoying to test.
@ -116,7 +118,8 @@ def do_test(create_signature,
show_debug_info)
else:
mlir = pywrap_mlir.experimental_convert_saved_model_v1_to_mlir(
save_model_path, ','.join([tf.saved_model.tag_constants.SERVING]),
save_model_path, exported_names,
','.join([tf.saved_model.tag_constants.SERVING]),
lift_variables, upgrade_legacy, show_debug_info)
if canonicalize:

View File

@ -113,6 +113,8 @@ void CreateTPUBridgePipeline(OpPassManager &pm) {
// Run another shape inference pass because resource decomposition might have
// created new partial types.
pm.addPass(TF::CreateTFShapeInferencePass());
// Note that the region-based control-flow produced here still contains
// function call ops which get inlined by the subsequent inliner pass.
pm.addPass(TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addPass(CreateTPUClusterCleanupAttributesPass());

View File

@ -78,7 +78,11 @@ void AddSupportedOpsUsingFolding(MLIRContext* context,
OperationName(TF::ConcatOffsetOp::getOperationName(), context),
OperationName(TF::EmptyOp::getOperationName(), context),
OperationName(TF::ListDiffOp::getOperationName(), context),
OperationName(TF::RankOp::getOperationName(), context),
OperationName(TF::RangeOp::getOperationName(), context),
OperationName(TF::ShapeOp::getOperationName(), context),
OperationName(TF::ShapeNOp::getOperationName(), context),
OperationName(TF::SizeOp::getOperationName(), context),
};
supported_ops->insert(allowlist_ops.begin(), allowlist_ops.end());

View File

@ -279,7 +279,10 @@ void CreateConvertMlirToXlaHloPipeline(
mlir::OpPassManager& pm, llvm::StringRef device_type,
llvm::MutableArrayRef<std::unique_ptr<mlir::Pass>>
custom_legalization_passes) {
// Note that the region-based control-flow produced here still contains
// function call ops which get inlined by the subsequent inliner pass.
pm.addPass(mlir::TF::CreateTFFunctionalControlFlowToRegions());
pm.addPass(mlir::createInlinerPass());
pm.addNestedPass<mlir::FuncOp>(mlir::TF::CreateDropWhileShapeInvariantPass());
pm.addNestedPass<mlir::FuncOp>(mlir::createCanonicalizerPass());
// The SCCP pass performs constant propagation across the IR, which, for
@ -317,6 +320,7 @@ void CreateConvertMlirToXlaHloPipeline(
// inside PromoteResourcesToArgs.
pm.addPass(mlir::mhlo::createLegalizeTFControlFlowPass());
pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
/*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
/*tf2xla_fallback_device_type=*/device_type));
@ -372,7 +376,7 @@ Status LegalizeToHlo(mlir::ModuleOp module_op, llvm::StringRef device_type,
if (failed(tf2xla.run(module_op))) {
return error_handler.Combine(
errors::InvalidArgument("TF to XLA legalization failed"));
errors::InvalidArgument("TF to XLA legalization failed: "));
}
if (VLOG_IS_ON(1))

View File

@ -29,7 +29,7 @@ package_group(
)
gentbl(
name = "xla_legalize_tf_inc_gen",
name = "legalize_tf_patterns_inc_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
("-gen-rewriters", "transforms/generated_legalize_tf.inc"),
@ -48,6 +48,22 @@ gentbl(
],
)
gentbl(
name = "xla_legalize_tf_passes_inc_gen",
compatible_with = get_compatible_with_cloud(),
tbl_outs = [
("-gen-pass-decls -name XLA", "transforms/xla_legalize_tf_passes.h.inc"),
],
tblgen = "@llvm-project//mlir:mlir-tblgen",
td_file = "transforms/xla_legalize_tf_passes.td",
td_relative_includes = [
"../hlo/include",
],
td_srcs = [
"@llvm-project//mlir:PassBaseTdFiles",
],
)
gentbl(
name = "xla_passes_inc_gen",
compatible_with = get_compatible_with_cloud(),
@ -72,8 +88,8 @@ gentbl(
cc_library(
name = "xla_passes",
srcs = [
"transforms/passes_detail.h",
"transforms/prepare_for_export.cc",
"transforms/xla_passes_detail.h",
],
hdrs = [
"transforms/passes.h",
@ -96,19 +112,25 @@ cc_library(
"transforms/legalize_tf.cc",
"transforms/legalize_tf_communication.cc",
"transforms/legalize_tf_control_flow.cc",
"transforms/legalize_tf_types.cc",
"transforms/xla_legalize_tf_passes_detail.h",
],
hdrs = [
"transforms/passes.h",
],
deps = [
":attribute_importer",
":legalize_tf_patterns_inc_gen",
":type_to_shape",
":xla_legalize_tf_passes_inc_gen",
":xla_legalize_tf_with_tf2xla",
":xla_passes",
"//tensorflow/compiler/mlir/hlo",
"//tensorflow/compiler/mlir/hlo:chlo_legalize_to_hlo",
"//tensorflow/compiler/mlir/hlo:convert_op_folder",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
"//tensorflow/compiler/mlir/tensorflow:tensorflow_types",
"//tensorflow/compiler/xla:shape_util",
"//tensorflow/compiler/xla:xla_data_proto_cc",
"//tensorflow/compiler/xla/client:padding",

View File

@ -0,0 +1,54 @@
// RUN: tf-opt -xla-legalize-tf-types %s | FileCheck %s
func @relu_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
// CHECK: func @relu_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
// CHECK-NEXT: %[[X:.*]] = "tf.Relu"(%arg0) : (tensor<1xi8>) -> tensor<1xi8>
%0 = "tf.Relu"(%arg0) : (tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8>
return %0: tensor<1x!tf.qint8>
}
func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1x!tf.qint8>, %arg2: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
// CHECK: func @if_qint8(%arg0: tensor<i1>, %arg1: tensor<1xi8>, %arg2: tensor<1xi8>) -> tensor<1xi8>
// CHECK-NEXT: %0 = "tf.IfRegion"(%arg0) ( {
// CHECK-NEXT: "tf.Yield"(%arg1) : (tensor<1xi8>) -> ()
// CHECK-NEXT: }, {
// CHECK-NEXT: "tf.Yield"(%arg2) : (tensor<1xi8>) -> ()
// CHECK-NEXT: }) {is_stateless = false} : (tensor<i1>) -> tensor<1xi8>
// CHECK-NEXT: return %0 : tensor<1xi8>
%0 = "tf.IfRegion"(%arg0) ( {
"tf.Yield"(%arg1) : (tensor<1x!tf.qint8>) -> ()
}, {
"tf.Yield"(%arg2) : (tensor<1x!tf.qint8>) -> ()
}) {is_stateless = false} : (tensor<i1>) -> tensor<1x!tf.qint8>
return %0 : tensor<1x!tf.qint8>
}
func @id_qint8(%arg0: tensor<1x!tf.qint8>) -> tensor<1x!tf.qint8> {
// CHECK: func @id_qint8(%arg0: tensor<1xi8>) -> tensor<1xi8> {
// CHECK-NEXT: return %arg0 : tensor<1xi8>
return %arg0: tensor<1x!tf.qint8>
}
func @id_qint16(%arg0: tensor<1x!tf.qint16>) -> tensor<1x!tf.qint16> {
// CHECK: func @id_qint16(%arg0: tensor<1xi16>) -> tensor<1xi16> {
// CHECK-NEXT: return %arg0 : tensor<1xi16>
return %arg0: tensor<1x!tf.qint16>
}
func @id_qint32(%arg0: tensor<1x!tf.qint32>) -> tensor<1x!tf.qint32> {
// CHECK: func @id_qint32(%arg0: tensor<1xi32>) -> tensor<1xi32> {
// CHECK-NEXT: return %arg0 : tensor<1xi32>
return %arg0: tensor<1x!tf.qint32>
}
func @id_quint8(%arg0: tensor<1x!tf.quint8>) -> tensor<1x!tf.quint8> {
// CHECK: func @id_quint8(%arg0: tensor<1xui8>) -> tensor<1xui8> {
// CHECK-NEXT: return %arg0 : tensor<1xui8>
return %arg0: tensor<1x!tf.quint8>
}
func @id_quint16(%arg0: tensor<1x!tf.quint16>) -> tensor<1x!tf.quint16> {
// CHECK: func @id_quint16(%arg0: tensor<1xui16>) -> tensor<1xui16> {
// CHECK-NEXT: return %arg0 : tensor<1xui16>
return %arg0: tensor<1x!tf.quint16>
}

View File

@ -496,7 +496,7 @@ def : Pat<(TF_EluGradOp AnyStaticShapeTensor:$gradients, AnyRankedTensor:$featur
// TODO(hinsu): Lower unsigned and quantized types after supporting
// them in GetScalarOfType.
def : Pat<(TF_ReluOp AnyRankedTensor:$input),
def : Pat<(TF_ReluOp AnyTensor:$input),
(HLOClient_BroadcastMaxOp
(HLO_ConstOp:$zero (GetScalarOfType<0> $input)), $input,
(BinBroadcastDimensions $zero, $input)),

View File

@ -0,0 +1,178 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// The TF dialect uses some TF types that are illegal in the MHLO dialect and
// some generic types that are legal in MHLO. This pass legalizes TF types into
// types that are legal in MHLO. For example, TF::Qint8Type is converted to i8.
// Rewrites here should run before TF to MHLO op legalizations are run.
// TODO(b/180234029): The rewrite here should be part of the LegalizeTF pass
// rather than its own pass.
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "mlir/IR/BuiltinOps.h" // from @llvm-project
#include "mlir/IR/BuiltinTypes.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Support/LLVM.h" // from @llvm-project
#include "mlir/Support/LogicalResult.h" // from @llvm-project
#include "mlir/Transforms/DialectConversion.h" // from @llvm-project
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_types.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes_detail.h"
#define DEBUG_TYPE "xla-legalize-tf-types"
namespace mlir {
namespace mhlo {
namespace {
bool IsIllegalElementType(Type type) {
return type
.isa<mlir::TF::Qint8Type, mlir::TF::Qint16Type, mlir::TF::Qint32Type,
mlir::TF::Quint8Type, mlir::TF::Quint16Type>();
}
Type ToLegalElementType(Type type) {
return TypeSwitch<Type, Type>(type)
.Case<mlir::TF::Qint8Type>([&type](Type) {
return mlir::IntegerType::get(type.getContext(), 8);
})
.Case<mlir::TF::Qint16Type>([&type](Type) {
return mlir::IntegerType::get(type.getContext(), 16);
})
.Case<mlir::TF::Qint32Type>([&type](Type) {
return mlir::IntegerType::get(type.getContext(), 32);
})
.Case<mlir::TF::Quint8Type>([&type](Type) {
return mlir::IntegerType::get(
type.getContext(), 8,
mlir::IntegerType::SignednessSemantics::Unsigned);
})
.Case<mlir::TF::Quint16Type>([&type](Type) {
return mlir::IntegerType::get(
type.getContext(), 16,
mlir::IntegerType::SignednessSemantics::Unsigned);
})
.Default([&type](Type) { return type; });
}
// TODO(b/180234863): What's below this line is generic so convert it to a
// utility.
bool IsIllegalType(Type type) {
return IsIllegalElementType(getElementTypeOrSelf(type));
}
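// Maps an illegal element type, or a shaped type whose element type is
// illegal, to its legal counterpart, e.g. tensor<2x!tf.qint8> -> tensor<2xi8>.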
Type ToLegalType(Type type) {
if (IsIllegalElementType(type)) return ToLegalElementType(type);
if (auto shaped = type.dyn_cast<ShapedType>()) {
Type elem = shaped.getElementType();
if (IsIllegalType(elem)) return shaped.clone(ToLegalType(elem));
}
return type;
}
class TfTypeConverter : public TypeConverter {
public:
TfTypeConverter() {
addConversion([](Type type) -> Type {
return IsIllegalType(type) ? ToLegalType(type) : type;
});
}
};
// An op is illegal iff it contains an illegal type.
class TfTypeConversionTarget : public ConversionTarget {
public:
explicit TfTypeConversionTarget(MLIRContext &ctx, TfTypeConverter &converter)
: ConversionTarget(ctx), converter_(converter) {
markUnknownOpDynamicallyLegal();
}
protected:
bool isDynamicallyLegal(Operation *op) const override {
// The FuncOp type can contain types that the op's operand and result types
// do not contain.
if (auto func = dyn_cast<FuncOp>(op)) {
if (!converter_.isSignatureLegal(func.getType())) return false;
}
return converter_.isLegal(op);
}
private:
TfTypeConverter &converter_;
};
class TfTypePattern : public ConversionPattern {
public:
TfTypePattern(MLIRContext *ctx, TypeConverter &converter)
: ConversionPattern(1, converter, MatchAnyOpTypeTag()) {}
// The dialect conversion framework will call this matchAndRewrite on each
// Operation in the IR tree. This matchAndRewrite call needs to update the
// Operation's results and child regions.
LogicalResult matchAndRewrite(
Operation *op, ArrayRef<Value> operands,
ConversionPatternRewriter &rewriter) const override {
// Update the results.
llvm::SmallVector<Type, 4> new_results;
if (failed(getTypeConverter()->convertTypes(op->getResultTypes(),
new_results)))
return failure();
// Update the regions. The dialect conversion framework wants new regions to
// be created and updated, rather than updating the old op. Thus we use an
// OperationState so we can add regions to the new op.
OperationState state(op->getLoc(), op->getName().getStringRef(), operands,
new_results, op->getAttrs(), op->getSuccessors());
for (Region &region : op->getRegions()) {
Region &new_region = *state.addRegion();
rewriter.inlineRegionBefore(region, new_region, new_region.begin());
if (failed(rewriter.convertRegionTypes(&new_region, *getTypeConverter())))
return failure();
}
rewriter.replaceOp(op, rewriter.createOperation(state)->getResults());
return success();
}
};
struct LegalizeTfTypesPass
: public LegalizeTfTypesPassBase<LegalizeTfTypesPass> {
void runOnOperation() override;
};
void LegalizeTfTypesPass::runOnOperation() {
TfTypeConverter converter;
OwningRewritePatternList patterns;
patterns.insert<TfTypePattern>(&getContext(), converter);
populateFuncOpTypeConversionPattern(patterns, &getContext(), converter);
TfTypeConversionTarget target(getContext(), converter);
if (failed(applyFullConversion(getOperation(), target, std::move(patterns))))
return signalPassFailure();
}
static PassRegistration<LegalizeTfTypesPass> registration(
"xla-legalize-tf-types",
"Replace TensorFlow types with types that are legal in the MHLO dialect");
} // namespace
std::unique_ptr<OperationPass<>> CreateLegalizeTfTypesPass() {
return std::make_unique<LegalizeTfTypesPass>();
}
} // namespace mhlo
} // namespace mlir
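For reference, a minimal sketch (not part of this change) of scheduling the new type-legalization pass ahead of the TF-to-MHLO op legalization in a standalone pass manager; `context`, `module`, and `device_type` are assumed to already exist, mirroring the CreateConvertMlirToXlaHloPipeline update elsewhere in this commit:
#include "mlir/Pass/PassManager.h"  // from @llvm-project
mlir::PassManager pm(&context);
// Replace quantized TF types (e.g. !tf.qint8) with plain integer types first.
pm.addPass(mlir::mhlo::CreateLegalizeTfTypesPass());
// Then run the usual TF -> MHLO op legalization.
pm.addNestedPass<mlir::FuncOp>(mlir::mhlo::createLegalizeTFPass(
    /*allow_partial_conversion=*/true, /*legalize_chlo=*/true,
    /*tf2xla_fallback_device_type=*/device_type));
if (mlir::failed(pm.run(module))) {
  // Handle legalization failure.
}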

View File

@ -49,6 +49,10 @@ std::unique_ptr<OperationPass<FuncOp>> createLegalizeTFPass(
std::unique_ptr<OperationPass<FuncOp>> createLegalizeTfWithTf2XlaPass(
llvm::StringRef device_type);
/// Replaces types that do not exist in MHLO with equivalent types that do
/// exist.
std::unique_ptr<OperationPass<ModuleOp>> CreateLegalizeTfTypesPass();
/// Adds the TF to XLA via TF2XLA rewrite patterns to the pattern list.
void PopulateLegalizeTfWithTf2XlaPatterns(llvm::StringRef device_type,
OwningRewritePatternList& patterns);

View File

@ -26,7 +26,7 @@ limitations under the License.
#include "mlir/IR/Types.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/hlo/include/mlir-hlo/Dialect/mhlo/IR/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes_detail.h"
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes_detail.h"
#define DEBUG_TYPE "xla-prepare-for-export"

View File

@ -0,0 +1,36 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Declare passes used in xla_legalize_tf.
include "mlir/Pass/PassBase.td"
def LegalizeTfTypesPass : Pass<"xla-legalize-tf-types"> {
let summary = "Replace TensorFlow types with types that are legal in the MHLO dialect";
let description = [{
The TF dialect uses some TF types that are illegal in the MHLO dialect and
some generic types that are legal in MHLO. This pass legalizes TF types into
types that are legal in MHLO. Rewrites here should run before TF to MHLO op
legalizations are run.
Specifically, this pass replaces each quantized integer type with the
corresponding ordinary type. For example, `TF::Qint8Type` is replaced with `i8`
everywhere it occurs. Types that are replaced are `TF::Qint8Type`,
`TF::Qint16Type`, `TF::Qint32Type`, `TF::Quint8Type`, and `TF::Quint16Type`.
}];
let constructor = "::mlir::mhlo::CreateLegalizeTfTypesPass()";
}

View File

@ -0,0 +1,31 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TF_PASSES_DETAIL_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TF_PASSES_DETAIL_H_
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace mhlo {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/xla/transforms/xla_legalize_tf_passes.h.inc"
} // namespace mhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_LEGALIZE_TF_PASSES_DETAIL_H_

View File

@ -0,0 +1,31 @@
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_DETAIL_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_DETAIL_H_
#include "mlir/IR/Dialect.h"
#include "mlir/Pass/Pass.h"
namespace mlir {
namespace mhlo {
#define GEN_PASS_CLASSES
#include "tensorflow/compiler/mlir/xla/transforms/xla_passes.h.inc"
} // namespace mhlo
} // namespace mlir
#endif // TENSORFLOW_COMPILER_MLIR_XLA_TRANSFORMS_XLA_PASSES_DETAIL_H_

View File

@ -467,6 +467,80 @@ Status GetTrtBroadcastShape(const TRT_TensorOrWeights& operand_l,
return Status::OK();
}
// Prepares a dynamic shape tensor for broadcast by adding leading 1 dimensions.
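// For example, an operand with dims [C] broadcast against broadcasted_nbDims=3
// is reshaped to [1, 1, C].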
Status DynamicBroadcast(nvinfer1::ITensor* operand, OpConverterParams* params,
nvinfer1::ITensor** output, int broadcasted_nbDims) {
int operand_nbDims = operand->getDimensions().nbDims;
if (broadcasted_nbDims > operand_nbDims) {
if (params->validation_only) return Status::OK();
int n_extra_dims = broadcasted_nbDims - operand_nbDims;
VLOG(2) << "Dynamic broadcast adding " << n_extra_dims << " leading 1s";
TF_RETURN_IF_ERROR(params->converter->DynamicReshape(
operand, {std::make_pair(0, operand_nbDims)}, params, output,
{n_extra_dims}));
} else {
*output = operand;
}
return Status::OK();
}
Status BroadcastWeights(std::unique_ptr<TRT_TensorOrWeights>& p,
nvinfer1::Dims broadcasted_dims) {
if (!p->is_weights()) return errors::Internal("Weight input expected");
if (p->GetTrtDims().nbDims != broadcasted_dims.nbDims) {
TRT_ShapedWeights weights(p->weights());
TF_RETURN_IF_ERROR(weights.SetShape(broadcasted_dims));
p = std::make_unique<TRT_TensorOrWeights>(weights);
}
return Status::OK();
}
Status ApplyBroadcast(std::unique_ptr<TRT_TensorOrWeights>& operand,
nvinfer1::Dims broadcasted_dims,
OpConverterParams* params) {
if (operand->is_weights()) {
TF_RETURN_IF_ERROR(BroadcastWeights(operand, broadcasted_dims));
} else {
nvinfer1::ITensor* tensor = nullptr;
auto is_static_shuffle_compatible = [](nvinfer1::Dims dims) {
return std::count(dims.d, dims.d + dims.nbDims, -1) <= 1;
};
if (is_static_shuffle_compatible(broadcasted_dims)) {
TF_RETURN_IF_ERROR(PrepareTensorForShape(
params->converter, *operand, broadcasted_dims,
params->validation_only, &tensor, params->node_def));
} else {
TF_RETURN_IF_ERROR(DynamicBroadcast(operand->tensor(), params, &tensor,
broadcasted_dims.nbDims));
}
operand = std::make_unique<TRT_TensorOrWeights>(tensor);
}
return Status::OK();
}
// Inserts leading 1 dimensions so that both operands have the same rank.
// Note: In implicit batch mode, weights' shape can include an explicit 1 batch
// dimension. The broadcasted shape might lose this leading batch dim, because
// the broadcasted shape does not include the implicit batch dim.
// TODO(tfeher): Other code blocks that use GetTrtBroadcastShape need to be
// fixed to use this routine to handle dynamic inputs. Eventually,
// GetTrtBroadcastShape should only be used by this routine.
Status BroadcastTensors(std::unique_ptr<TRT_TensorOrWeights>& operand_l,
std::unique_ptr<TRT_TensorOrWeights>& operand_r,
bool check_feasibility, OpConverterParams* params) {
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
*operand_l, *operand_r, check_feasibility, params->use_implicit_batch,
&broadcasted_dims_l, &broadcasted_dims_r));
if (params->validation_only) return Status::OK();
TF_RETURN_IF_ERROR(ApplyBroadcast(operand_l, broadcasted_dims_l, params));
TF_RETURN_IF_ERROR(ApplyBroadcast(operand_r, broadcasted_dims_r, params));
return Status::OK();
}
nvinfer1::ITensor* Converter::CreateConstantLayer(
const TRT_ShapedWeights& weights, const nvinfer1::Dims& dims) {
nvinfer1::Weights trt_weights = weights.GetTrtWeights();
@ -731,6 +805,17 @@ Status TRT_ShapedWeights::SetValues(T value) {
return Status::OK();
}
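// Updates the stored shape of the weights in place; the new dims must describe
// the same total number of elements as the current shape.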
Status TRT_ShapedWeights::SetShape(nvinfer1::Dims dims) {
if (this->count() != TrtWeightDimsNumElements(dims)) {
VLOG(2) << "Changing shape from "
<< tensorflow::tensorrt::DebugString(shape_) << ", to "
<< tensorflow::tensorrt::DebugString(dims);
return errors::Internal("SetShape would change number of elements");
}
shape_ = dims;
return Status::OK();
}
size_t TRT_ShapedWeights::size_bytes() const {
size_t data_type_size = -1;
switch (type_) {
@ -1707,14 +1792,18 @@ Status PrepareTensorForShape(Converter* converter,
const NodeDef& node_def,
absl::optional<int> op_instance) {
const nvinfer1::Dims input_dims = input.GetTrtDims();
// If one of input_dims and dims doesn't have static shape, it means some of
// the dims are unknown or need to be inferred. And we don't do further checks
// but rely on the caller to not make mistakes.
// Otherwise we do simple check to make sure the total sizes are the same.
// The input shape may have -1s for dynamic shape. The target shape may have
// 0s representing copy over the corresponding input dimensions. It may also
// have at most one -1 representing a dimension value that needs to be
// inferred. If none of those special values present, we verify that the total
// sizes of the input and output shape are the same.
// TODO(tfeher): Verify that the total sizes of the input and output shape are
// the same in the presence of 0s but no -1 in the target shape.
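// For example, an input with dims [2, 3, 4] and a target shape of [0, -1]
// reshapes to [2, 12]: the 0 copies the first input dim and the -1 is inferred.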
// If an input is a weight, it is going to become a tensor via
// CreateConstantLayer. So we can treat it as a tensor for
// AreDimsStaticWithDifferentSize(). This really only matters for 0-D tensors.
if (AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
if (Prod(dims) > 0 &&
AreDimsStaticWithDifferentSize(input_dims, dims, /*is_tensor=*/true)) {
return errors::InvalidArgument(
"Incompatible shapes: ", DebugString(input_dims), " vs. ",
DebugString(dims));
@ -2700,8 +2789,13 @@ Status Converter::DynamicReshape(nvinfer1::ITensor* input,
int slice_instance = i * max_num_slices + op_instance_value;
// maybe_add_a_dimension(i);
if (i < size_for_added_dims.size() && size_for_added_dims[i] >= 0) {
nvinfer1::Dims dims{1, {1}};
if (size_for_added_dims[i] > 0) {
dims.d[0] = size_for_added_dims[i];
}
TF_RETURN_IF_ERROR(
CreateScalarConstant(params, size_for_added_dims[i], &tensor));
CreateScalarConstant(params, std::min(size_for_added_dims[i], 1),
&tensor, nvinfer1::DataType::kINT32, dims));
concat_inputs.push_back(tensor);
}
if (i < slices.size()) {
@ -5422,30 +5516,86 @@ Status ConvertGather(OpConverterParams* params) {
return Status::OK();
}
// Converts the input matrix multiplication node to a fully connected (FC) layer
// if possible, as the FC layer has more tactics and INT8 implementations. Sets
// *converted to true if the node is converted. Otherwise, *converted == false
// with an OK status indicates that the node can't be converted to an FC layer.
// An error status indicates internal problems during conversion.
Status ConvertFullyConnectedHelper(OpConverterParams* params,
nvinfer1::ITensor* tensor_a,
TRT_ShapedWeights weights_b,
bool transpose_b) {
// Reshape input to 3D - this will be a no-op unless using int8 precision.
auto input_dim = tensor_a->getDimensions();
while (input_dim.nbDims < 3) {
input_dim.d[input_dim.nbDims++] = 1;
TRT_TensorOrWeights input_a,
TRT_TensorOrWeights input_b,
bool transpose_a, bool transpose_b,
bool* converted) {
*converted = false;
if (!(!transpose_a && input_a.is_tensor() && input_b.is_weights())) {
VLOG(2) << "Not FC compatible, A must be non transposed tensor, and B "
"must be constant.";
return Status::OK();
}
const auto& node_def = params->node_def;
if (!params->use_implicit_batch && input_b.GetTrtDims().nbDims > 2 &&
input_b.GetTrtDims().d[0] != 1) {
// Implicit broadcasting, if needed, has already been applied to the inputs
// so that the two operands have the same rank here.
// If the inputs have rank >= 3, then d[0] is the explicit batch dimension.
// The weight (input_b) must have batch size 1 in implicit batch mode.
VLOG(2) << "Not FC compatible, if B has an explicit batch dimension, then "
"it must be 1.";
return Status::OK();
}
nvinfer1::Dims input_dim = input_a.GetTrtDims();
if (input_dim.d[input_dim.nbDims - 1] == -1) {
VLOG(2) << "Not FC compatible, last dim of A must be static.";
return Status::OK();
}
if (input_dim.nbDims + 2 > nvinfer1::Dims::MAX_DIMS) {
VLOG(2) << "Not FC compatible, cannot expand A's shape.";
return Status::OK();
}
// Add two trailing 1's because FC layer combines the last three dims.
nvinfer1::ITensor* tensor_a = nullptr;
nvinfer1::Dims reshape_dim{input_dim.nbDims + 2, {}};
// The empty braces initialize the elements of reshape_dim.d to 0. A value 0 in
// reshape_dim.d[i] will preserve the i-th dimension value from the shape of
// input_a.
reshape_dim.d[input_dim.nbDims] = 1;
reshape_dim.d[input_dim.nbDims + 1] = 1;
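// For example, if A has dims [M, K], reshape_dim is {0, 0, 1, 1}, which
// reshapes A to [M, K, 1, 1] (the zeros preserve the original dims).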
const NodeDef& node_def = params->node_def;
TF_RETURN_IF_ERROR(PrepareTensorForShape(
params->converter, TRT_TensorOrWeights(tensor_a), input_dim,
params->converter, input_a, reshape_dim,
/*validation_only=*/false, &tensor_a, node_def, /*op_instance=*/0));
VLOG(2) << "New shape of A " << DebugString(tensor_a->getDimensions());
TRT_ShapedWeights weights_b = input_b.weights();
TRT_ShapedWeights weights_2D(weights_b);
if (weights_b.shape_.nbDims > 2) {
// Combine first nbDims-1 dims into a single dim, e.g. for a 4D tensor we
// transform [N, H, W, C] -> [N*H*W, C].
int k = weights_b.shape_.d[weights_b.shape_.nbDims - 1];
nvinfer1::Dims dims{2, {static_cast<int>(weights_b.count() / k), k}};
TF_RETURN_IF_ERROR(weights_2D.SetShape(dims));
}
// FC layer will transpose weights, so we need to pre-transpose.
TRT_ShapedWeights weights(weights_b.TrtDType());
TRT_ShapedWeights weights(weights_2D.TrtDType());
if (!transpose_b) {
weights = params->weight_store->GetTempWeights(weights_b);
ReorderCKtoKC(weights_b, &weights);
weights = params->weight_store->GetTempWeights(weights_2D);
ReorderCKtoKC(weights_2D, &weights);
} else {
weights = weights_b;
weights = weights_2D;
}
TRT_ShapedWeights biases(weights.TrtDType());
const int noutput = weights.shape_.d[0];
int k = weights.shape_.d[weights.shape_.nbDims - 1];
const int noutput = weights.count() / k;
VLOG(2) << "Using fully connected layer with k=" << k
<< ", n_output=" << noutput
<< " weights shape: " << DebugString(weights.shape_) << " to convert "
<< node_def.op();
nvinfer1::IFullyConnectedLayer* layer =
params->converter->network()->addFullyConnected(
*tensor_a, noutput, weights.GetTrtWeights(), biases.GetTrtWeights());
@ -5454,14 +5604,19 @@ Status ConvertFullyConnectedHelper(OpConverterParams* params,
params->converter->SetLayerName(layer, node_def);
nvinfer1::ITensor* output_tensor = layer->getOutput(0);
// Reshape output to 1D - this will be a no-op unless using int8 precision.
// A fully connected layer produces output with two trailing singleton
// dimensions. We remove these.
auto output_dim = output_tensor->getDimensions();
output_dim.nbDims = 1;
output_dim.nbDims -= 2;
// A zero in output_dim indicates copying the corresponding input dimension
// value during reshape.
std::fill(output_dim.d, output_dim.d + output_dim.nbDims, 0);
TF_RETURN_IF_ERROR(PrepareTensorForShape(
params->converter, TRT_TensorOrWeights(output_tensor), output_dim,
/*validation_only=*/false, &output_tensor, node_def, /*op_instance=*/1));
/*validation_only=*/false, &output_tensor, node_def,
/*op_instance=*/1));
params->outputs->push_back(TRT_TensorOrWeights(output_tensor));
*converted = true;
return Status::OK();
}
@ -5469,82 +5624,66 @@ Status ConvertMatMulHelper(OpConverterParams* params,
TRT_TensorOrWeights input_a,
TRT_TensorOrWeights input_b, bool transpose_a,
bool transpose_b) {
// TODO: ReorderCKtoKC is currently not general enough to transpose weights
// that are not 2D.
if ((transpose_a && input_a.is_weights() &&
input_a.GetTrtDims().nbDims != 2) ||
(transpose_b && input_b.is_weights() &&
input_b.GetTrtDims().nbDims != 2)) {
return errors::InvalidArgument(
"Cannot currently transpose constant input if it is not 2 dimensional");
if (params->use_implicit_batch) {
// In implicit batch mode we are very limited in when we can multiply 2D
// matrices. If input_a is a 2D tensor, then nbDims==1 (the implicit batch dim
// is not counted). If A is not transposed and B is a weight, then we can
// convert this by treating A as a batch of vectors. This is the only way
// to implement MatMul with 2D input in implicit batch mode.
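// For example, a TF MatMul of a [B, 2] tensor with [2, 3] constant weights
// is still convertible here by treating A as a batch of 2-vectors.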
if ((input_a.GetTrtDims().nbDims < 2 &&
(transpose_a || !input_b.is_weights())) ||
(input_b.GetTrtDims().nbDims < 2)) {
return errors::InvalidArgument(
"MatMul with 2D tensors requires explicit batch mode, or that tensor"
" A is not transposed and B is a constant tensor.");
}
}
// If A is a tensor, we can only transpose if it is at least 3D in TF,
// or TRT will not do the correct transposition.
if (transpose_a && input_a.is_tensor() && input_a.GetTrtDims().nbDims < 2) {
return errors::InvalidArgument(
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions.");
}
// If B is a tensor, then it must be at least 3D in TF,
// or TRT won't be able to handle the multiply correctly.
if (input_b.is_tensor() && input_b.GetTrtDims().nbDims < 2) {
return errors::InvalidArgument(
"Second input must either be a constant, or contain at least 2 "
"non-batch dimensions.");
}
if (params->validation_only) return Status::OK();
// If an FC layer can be used and would be faster, use that instead.
const bool can_use_fc =
!transpose_a && input_a.is_tensor() && input_b.is_weights();
const bool should_use_fc = can_use_fc && input_a.GetTrtDims().nbDims >= 3 &&
input_b.GetTrtDims().nbDims == 2;
// If int8 is specified, FC must be used unless it is not compatible, as MM
// does not support int8 at this time.
if (should_use_fc || (can_use_fc && params->converter->precision_mode() ==
TrtPrecisionMode::INT8)) {
return ConvertFullyConnectedHelper(params, input_a.tensor(),
input_b.weights(), transpose_b);
}
bool converted = false;
TF_RETURN_IF_ERROR(ConvertFullyConnectedHelper(
params, input_a, input_b, transpose_a, transpose_b, &converted));
if (converted == true) return Status::OK();
const auto get_matrix_op = [](nvinfer1::ITensor* in,
bool transpose) -> nvinfer1::MatrixOperation {
return (in->getDimensions().nbDims < 2) ? nvinfer1::MatrixOperation::kVECTOR
: (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
: nvinfer1::MatrixOperation::kNONE;
};
// If the MatMul operand is a constant, applies transposes at conversion-time
// as necessary. If the operand is a tensor, does nothing. If required
// transposes were applied, sets transpose to false.
const auto prepare_matmul_operand =
[&params](TRT_TensorOrWeights operand,
bool* transpose) -> nvinfer1::ITensor* {
const auto convert_to_itensor =
[&params](TRT_TensorOrWeights operand) -> nvinfer1::ITensor* {
if (operand.is_tensor()) {
return operand.tensor();
} else {
TRT_ShapedWeights weights(operand.weights().TrtDType());
if (*transpose) {
weights = params->weight_store->GetTempWeights(operand.weights());
ReorderCKtoKC(operand.weights(), &weights);
// Weights have been transposed, can set transpose to false
*transpose = false;
} else {
weights = operand.weights();
}
return params->converter->CreateConstantLayer(weights, weights.shape_);
return params->converter->CreateConstantLayer(operand.weights(),
operand.GetTrtDims());
}
};
nvinfer1::ITensor* tensor_a = prepare_matmul_operand(input_a, &transpose_a);
nvinfer1::ITensor* tensor_b = prepare_matmul_operand(input_b, &transpose_b);
nvinfer1::ITensor* tensor_a = convert_to_itensor(input_a);
nvinfer1::ITensor* tensor_b = convert_to_itensor(input_b);
const auto get_matrix_op = [](nvinfer1::ITensor* in,
bool transpose) -> nvinfer1::MatrixOperation {
return (transpose) ? nvinfer1::MatrixOperation::kTRANSPOSE
: nvinfer1::MatrixOperation::kNONE;
};
nvinfer1::MatrixOperation op_a, op_b;
// Note: In implicit batch mode kTRANSPOSE and kNONE are only valid if the
// matrix has at least 2 non-batch dimensions. In implicit batch mode, if A has
// 1 dim (excluding the batch dim), then we can only use kVECTOR, which treats
// matrix A as a batch of vectors.
op_a = (tensor_a->getDimensions().nbDims < 2)
? nvinfer1::MatrixOperation::kVECTOR
: get_matrix_op(tensor_a, transpose_a);
// In implicit batch mode, if B has only 1 dim (excluding the batch dim) then
// we have already rejected that case and don't convert. One could consider
// using the kVECTOR flag to express C = MatMul(A, B.T) if A is a weight, but
// the result
// will not have the correct shape: in TRT's implicit batch implementation,
// the result is a batch of vectors D_ji = A_ik * B_jk, where j is the batch
// dimension. In contrast, the TF MatMul op produces C = D.T, and we cannot
// transpose over the batch dimension (implicit batch mode).
op_b = get_matrix_op(tensor_b, transpose_b);
nvinfer1::IMatrixMultiplyLayer* layer =
params->converter->network()->addMatrixMultiply(
*tensor_a, get_matrix_op(tensor_a, transpose_a), *tensor_b,
get_matrix_op(tensor_b, transpose_b));
params->converter->network()->addMatrixMultiply(*tensor_a, op_a,
*tensor_b, op_b);
const auto& node_def = params->node_def;
TFTRT_RETURN_ERROR_IF_NULLPTR(layer, node_def.name());
@ -5582,44 +5721,45 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
" inputs but expected 2, at ",
node_def.name());
}
// TODO(tmorris): Enable once false is updated to mean either tensor or weight
// TF_RETURN_IF_ERROR(CheckInputsWeights(*params, {{"x", false}, {"y",
// false}}));
TF_RETURN_IF_ERROR(CheckInputsWeights(
*params, {{"x", TrtInputArg::kBoth}, {"y", TrtInputArg::kBoth}}));
// TODO(tfeher): Consider adding INT8 type because FC layer can support it.
TF_RETURN_IF_ERROR(
AllowDataTypes(*params, {DataType::DT_FLOAT, DataType::DT_HALF}));
if (inputs.at(0).is_weights() && inputs.at(1).is_weights()) {
return errors::InvalidArgument(
"All inputs are weights, but Grappler is expected to fold them.");
}
if (inputs.at(0).is_tensor() && inputs.at(1).is_tensor() &&
inputs.at(0).GetTrtDims().nbDims != inputs.at(1).GetTrtDims().nbDims) {
return errors::Unimplemented(
"Inputs must have the same rank if they are both tensors.");
}
TFAttrs attrs(node_def);
const bool transpose_a = attrs.get<bool>("adj_x");
const bool transpose_b = attrs.get<bool>("adj_y");
// There is no way to batch constants in TRT. Example:
// Tensor with TF Dims: 12 5 3 -> TRT Dims: 5 3
// Weight with TF Dims: 12 3 6 -> TRT Dims: 12 3 6
// It is not possible to treat the weight input as a batched [3, 6] tensor.
// In case input_l is weight, check whether input_l has implicit batch mode
// compatible batch dim.
const auto check_weight_is_not_batched =
[](const TRT_TensorOrWeights& input_l,
const TRT_TensorOrWeights& input_r) {
// If input_l is a weight, then input_r must be a tensor because
// otherwise the op would be handled by Grappler.
// There is no way to batch constants in TRT using implicit batch mode.
// Example:
// Tensor with TF Dims: 12 5 3 -> TRT Dims: 5 3
// Weight with TF Dims: 12 3 6 -> TRT Dims: 12 3 6
// It is not possible to treat the weight input as a batched [3, 6]
// tensor. Batched weight tensors must have batch dim = 1 (after the
// broadcast).
if (input_l.is_weights() &&
input_l.GetTrtDims().nbDims > input_r.GetTrtDims().nbDims &&
input_l.GetTrtDims().d[0] != 1) {
return errors::Unimplemented(
"TensorRT does not support batched constants.");
"TensorRT does not support batched constants in implicit batch "
"mode.");
}
return Status::OK();
};
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(0), inputs.at(1)));
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(1), inputs.at(0)));
if (params->use_implicit_batch) {
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(0), inputs.at(1)));
TF_RETURN_IF_ERROR(check_weight_is_not_batched(inputs.at(1), inputs.at(0)));
}
// Broadcast inputs. We don't check feasibility since the dimensions in a
// MatMul don't need to match. For example, consider a valid set of inputs
@ -5627,22 +5767,14 @@ Status ConvertBatchMatMul(OpConverterParams* params) {
// input 0: [N, T, C]
// input 1: [1, C, K]
// Since C != K and T != C, the feasibility check would fail.
nvinfer1::Dims broadcasted_dims_l, broadcasted_dims_r;
TF_RETURN_IF_ERROR(GetTrtBroadcastShape(
inputs.at(0), inputs.at(1), /*check_feasibility=*/false,
params->use_implicit_batch, &broadcasted_dims_l, &broadcasted_dims_r));
nvinfer1::ITensor* tensor_l = nullptr;
nvinfer1::ITensor* tensor_r = nullptr;
TF_RETURN_IF_ERROR(
PrepareTensorForShape(params->converter, inputs.at(0), broadcasted_dims_l,
params->validation_only, &tensor_l, node_def));
TF_RETURN_IF_ERROR(
PrepareTensorForShape(params->converter, inputs.at(1), broadcasted_dims_r,
params->validation_only, &tensor_r, node_def));
auto input_l = std::make_unique<TRT_TensorOrWeights>(inputs.at(0));
auto input_r = std::make_unique<TRT_TensorOrWeights>(inputs.at(1));
TF_RETURN_IF_ERROR(BroadcastTensors(input_l, input_r,
/*check_feasibility=*/false, params));
if (params->validation_only) return Status::OK();
return ConvertMatMulHelper(params, TRT_TensorOrWeights(tensor_l),
TRT_TensorOrWeights(tensor_r), transpose_a,
return ConvertMatMulHelper(params, *input_l, *input_r, transpose_a,
transpose_b);
}

View File

@ -192,6 +192,8 @@ class TRT_ShapedWeights {
template <typename T>
Status SetValues(T value);
Status SetShape(nvinfer1::Dims dims);
int64_t count() const;
size_t size_bytes() const;
@ -558,10 +560,12 @@ class Converter {
// This can be achieved by calling DynamicReshape(input, {{2,4},{0,2}},
// params).
//
// Before each slice we can insert a new dim if the corresponding
// Before each slice we can insert new dims if the corresponding
// size_for_added_dims element is not negative. The size_for_added_dims array
// can have more than slices.size() elements, in order to insert a dimension
// ater the last slice.
// after the last slice. For example, to add two leading 1 dimensions, and
// three trailing 1 dimensions, call DynamicReshape(input, {{0,nbDims}},
// {2, 3}).
//
// Parameters:
// input - input tensor

View File

@ -712,8 +712,8 @@ TEST(TrtNodeValidator, IsTensorRTCandidate) {
ExpectStatus(
validator.IsTensorRTCandidate(incompatible_matmul.operation.node()),
error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions.");
"MatMul with 2D tensors requires explicit batch mode, or that tensor A "
"is not transposed and B is a constant tensor.");
ExpectStatus(validator.IsTensorRTCandidate(unsupported_op.operation.node()),
error::UNIMPLEMENTED, "Op type Erf is not supported");
ExpectStatus(validator.IsTensorRTCandidate(
@ -2466,87 +2466,138 @@ TEST_P(OpConverter_FP32_Test, ConvertShape) {
}
}
// Helper function for testing MatMul and BatchMatMul
// get_matmul corresponds to the function used to generate the node. It should
// accept (DataType, transpose_a, transpose_b) as parameters.
struct MatMulTestParams {
std::vector<int> shape_a;
std::vector<int> values_a;
bool transpose_a;
std::vector<int> shape_b;
std::vector<int> values_b;
bool transpose_b;
std::vector<int> expected_shape;
std::vector<int> expected_output;
};
// Helper function for testing MatMul and BatchMatMul. get_matmul is a function
// used to generate the node. It accepts (DataType, transpose_a, transpose_b) as
// parameters.
void TestMatMulHelper(
OpConverterTest* test,
ParameterizedOpConverterTestBase* test,
const std::function<NodeDef(DataType, bool, bool)>& get_matmul,
const std::string& op_name) {
// HACK: This needs to be done in a better way.
const bool is_batch_matmul = op_name == "BatchMatMul";
const std::vector<MatMulTestParams>& params) {
{
// Unsupported data type.
test->Reset();
NodeDef node_def = get_matmul(DT_INT32, false, false);
test->AddTestTensor("input", {2}, /*batch_size=*/1,
nvinfer1::DataType::kINT32);
test->AddTestTensor("input", {1, 2}, DT_INT32, {});
test->AddTestWeights<int32>("weights", {2, 1}, {3, 5});
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
StrCat("Data type int32 is not supported for ", op_name,
StrCat("Data type int32 is not supported for ", node_def.op(),
", must be one of [float, half], at my_matmul")
.c_str());
}
// OK.
for (bool transpose_a : {false, true}) {
for (bool transpose_b : {false, true}) {
test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, transpose_a, transpose_b);
test->AddTestTensor("input", {2}, /*batch_size=*/1);
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) {
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"TensorRT does not support batched constants.");
continue;
} else if (transpose_a) {
test->RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions");
continue;
}
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test->AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", test->ConstructTensor<float>(2)}};
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(2, 3));
}
}
// FC conversion depends on whether the last dim of A is known or not. In
// Dynamic shape mode, we will check whether A is handled correctly if it has
// a partially known input shape (last dim known).
std::vector<bool> a_test_partial_shape_values{false};
if (test->get_trt_mode() == TrtTestMode::kDynamicShape) {
a_test_partial_shape_values.push_back(true);
}
// OK, 3D inputs
for (bool transpose_b : {false, true}) {
test->Reset();
NodeDef node_def = get_matmul(DT_FLOAT, /*transpose_a=*/false, transpose_b);
test->AddTestTensor("input", {2}, /*batch_size=*/1);
test->AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
if (is_batch_matmul) {
test->RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"TensorRT does not support batched constants.");
continue;
}
test->RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(test->GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", test->AsTensor<float>({0, 1})}};
DataVec output_data{{"my_matmul", test->ConstructTensor<float>(2)}};
TF_EXPECT_OK(test->BuildAndRun(input_data, &output_data));
if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(1, 3));
} else {
EXPECT_THAT(GetSpanForData<float>(output_data[0]), ElementsAre(2, 3));
for (auto p : params) {
for (bool a_is_tensor : {true, false}) {
for (bool b_is_tensor : {true, false}) {
for (bool a_partial_shape : a_test_partial_shape_values) {
if (a_partial_shape && !a_is_tensor) {
// Only tensors can have partial shape.
continue;
}
if (!a_is_tensor && !b_is_tensor) {
// Skip test when both args are weights. We do not convert this
// since const folding eliminates this case.
continue;
}
SCOPED_TRACE(StrCat("A", p.transpose_a ? ".T" : "", " is ",
a_is_tensor ? "tensor" : "weight", ", B",
p.transpose_b ? ".T" : "", " is ",
b_is_tensor ? "tensor " : "weight, rank A ",
p.shape_a.size(), ", rank B ", p.shape_b.size()));
test->Reset();
NodeDef node_def =
get_matmul(test->get_tf_type(), p.transpose_a, p.transpose_b);
const bool is_batch_matmul = node_def.op() == "BatchMatMul";
if (a_is_tensor) {
if (a_partial_shape) {
// Prepare a partial shape for A where only the last dim is known.
std::vector<int> partial_shape(p.shape_a.size(), -1);
int k = p.shape_a.size() - 1;
partial_shape.at(k) = p.shape_a.at(k);
test->AddTestTensor("input", p.shape_a, test->get_tf_type(),
p.values_a, partial_shape);
} else {
test->AddTestTensor("input", p.shape_a, p.values_a);
}
} else {
test->AddTestWeights("input", p.shape_a, p.values_a,
test->get_tf_type());
}
if (b_is_tensor) {
if (a_is_tensor && p.shape_a[0] != p.shape_b[0] &&
test->get_trt_mode() == TrtTestMode::kImplicitBatch) {
VLOG(2) << "Skipping test with inpcompatible batch dimensions";
continue;
}
test->AddTestTensor("weights", p.shape_b, p.values_b);
} else {
test->AddTestWeights("weights", p.shape_b, p.values_b,
test->get_tf_type());
}
Status conversion_status = Status::OK();
if (test->get_trt_mode() == TrtTestMode::kImplicitBatch) {
// Implicit batch mode has several restrictions. We change the conversion
// status accordingly.
if (is_batch_matmul) {
if (a_is_tensor && p.shape_a.size() < p.shape_b.size()) {
conversion_status = errors::InvalidArgument(
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims ",
p.shape_a.size(), " vs broadcast #dims ", p.shape_b.size(),
")");
}
if (b_is_tensor && p.shape_b.size() < p.shape_a.size()) {
conversion_status = errors::InvalidArgument(
"Broadcasting beyond batch dimension is not supported "
"(tensor #dims ",
p.shape_b.size(), " vs broadcast #dims ", p.shape_a.size(),
")");
}
if ((!a_is_tensor || !b_is_tensor) && p.shape_a[0] != 1) {
conversion_status = errors::Unimplemented(
"TensorRT does not support batched constants in implicit "
"batch mode.");
}
} else if ((a_is_tensor && p.shape_a.size() <= 2 &&
(p.transpose_a || b_is_tensor)) ||
(b_is_tensor && p.shape_b.size() <= 2)) {
conversion_status = errors::InvalidArgument(
"MatMul with 2D tensors requires explicit batch mode, or that"
" tensor A is not transposed and B is a constant tensor.");
}
}
test->TestOpConverter("my_matmul", node_def, p.expected_shape,
conversion_status, Status::OK(),
ElementsAreArray(p.expected_output));
if (!conversion_status.ok()) {
VLOG(2) << "Converted with status " << conversion_status;
}
VLOG(2) << "== Finished test iteration ==";
}
}
}
}
}
@ -2563,7 +2614,39 @@ void CheckAddedLayers(OpConverterTest* test, bool expect_found) {
EXPECT_EQ(expect_found, layer_found);
}
TEST_F(OpConverterTest, ConvertMatMul) {
std::vector<MatMulTestParams> GetMatMulTestParams() {
std::vector<MatMulTestParams> params{
// clang-format off
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, false, // A (shape, val, T?)
{2, 2}, {0, 1, 2, 3}, false, // B (shape, val, T?)
{2, 2}, {2, 3, 6, 11}}, // result (shape, val)
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, false,
{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {1, 3, 3, 13}},
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {0, 1, 2, 3}, false,
{2, 2}, {4, 6, 6, 10}},
MatMulTestParams{{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {0, 1, 2, 3}, true,
{2, 2}, {2, 6, 3, 11}},
MatMulTestParams{{2, 3}, {0, 1, 2, 3, 4, 5}, false,
{2, 3}, {1, 2, 3, 4, 5, 6}, true,
{2, 2}, {8, 17, 26, 62}},
MatMulTestParams{{2, 3}, {0, 1, 2, 3, 4, 5}, true,
{2, 3}, {1, 2, 3, 4, 5, 6}, false,
{3, 3}, {12, 15, 18, 17, 22, 27, 22, 29, 36}},
MatMulTestParams{{3, 2}, {0, 1, 2, 3, 4, 5}, false,
{2, 3}, {1, 2, 3, 4, 5, 6}, false,
{3, 3}, {4, 5, 6, 14, 19, 24, 24, 33, 42}},
MatMulTestParams{{3, 2}, {0, 1, 2, 3, 4, 5}, true,
{2, 3}, {1, 2, 3, 4, 5, 6}, true,
{2, 2}, {16, 34, 22, 49}},
// clang-format on
};
return params;
}
TEST_P(OpConverter_FP32_Test, ConvertMatMul) {
// Get the NodeDef for MatMul.
auto get_matmul_nodedef = [](DataType dtype, bool transpose_a,
bool transpose_b) -> NodeDef {
@ -2577,64 +2660,10 @@ TEST_F(OpConverterTest, ConvertMatMul) {
return matmul.operation.node()->def();
};
// Additional test cases specific to MatMul
{
// Can only transpose A if it is 2D in TRT
Reset();
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, true, false);
AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Cannot transpose first input if it is a tensor with fewer than 2 "
"non-batch dimensions.");
}
{
// B must always have 2 non-batch dimensions
Reset();
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {2}, /*batch_size=*/1);
AddTestTensor("weights", {2}, /*batch_size=*/1);
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Second input must either be a constant, or contain at least 2 "
"non-batch dimensions.");
}
{
// We can never transpose weights that are not 2D.
Reset();
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, true, false);
AddTestWeights<float>("input", {1, 1, 2}, {0, 1});
AddTestTensor("weights", {2, 2}, /*batch_size=*/1);
RunValidationAndConversion(
node_def, error::INVALID_ARGUMENT,
"Cannot currently transpose constant input if it is not 2 dimensional");
}
{
// Make sure that INT8 mode uses IFullyConnectedLayer when possible.
Reset(TrtPrecisionMode::INT8);
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {2, 1, 1});
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
RunValidationAndConversion(node_def);
CheckAddedLayers<nvinfer1::IMatrixMultiplyLayer>(this, false);
CheckAddedLayers<nvinfer1::IFullyConnectedLayer>(this, true);
}
{
// Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not
// compatible. In this case we can't use FC because weights is a tensor.
Reset(TrtPrecisionMode::INT8);
NodeDef node_def = get_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {2, 1, 1});
AddTestTensor("weights", {2, 2});
RunValidationAndConversion(node_def);
CheckAddedLayers<nvinfer1::IMatrixMultiplyLayer>(this, true);
CheckAddedLayers<nvinfer1::IFullyConnectedLayer>(this, false);
}
TestMatMulHelper(this, get_matmul_nodedef, "MatMul");
TestMatMulHelper(this, get_matmul_nodedef, GetMatMulTestParams());
}
TEST_F(OpConverterTest, ConvertBatchMatMul) {
TEST_P(OpConverter_FP32_Test, ConvertBatchMatMul) {
// Get the NodeDef for BatchMatMul.
auto get_batch_matmul_nodedef = [](DataType dtype, bool transpose_a,
bool transpose_b) -> NodeDef {
@ -2648,61 +2677,93 @@ TEST_F(OpConverterTest, ConvertBatchMatMul) {
return matmul.operation.node()->def();
};
{
// Can't broadcast two tensor inputs of different rank.
Reset();
NodeDef node_def = get_batch_matmul_nodedef(DT_FLOAT, false, false);
AddTestTensor("input", {1, 2, 2}, /*batch_size=*/2);
AddTestTensor("weights", {2}, /*batch_size=*/2);
RunValidationAndConversion(
node_def, error::UNIMPLEMENTED,
"Inputs must have the same rank if they are both tensors.");
}
{
// Make sure that INT8 mode doesn't try to use IFullyConnectedLayer when not
// compatible. In this case we can't use FC because transpose_a is true.
Reset(TrtPrecisionMode::INT8);
NodeDef node_def = get_batch_matmul_nodedef(DT_FLOAT, true, false);
AddTestTensor("input", {1, 2, 2});
AddTestWeights<float>("weights", {2, 2}, {0, 1, 2, 3});
RunValidationAndConversion(node_def);
CheckAddedLayers<nvinfer1::IMatrixMultiplyLayer>(this, true);
CheckAddedLayers<nvinfer1::IFullyConnectedLayer>(this, false);
}
// We derive test data from the MatMul test params by adding extra leading
// dimensions.
std::vector<MatMulTestParams> params_2d = GetMatMulTestParams();
std::vector<MatMulTestParams> params;
params.reserve(params_2d.size() * 3 + 1);
for (bool transpose_a : {false, true}) {
for (bool transpose_b : {false, true}) {
Reset();
NodeDef node_def =
get_batch_matmul_nodedef(DT_FLOAT, transpose_a, transpose_b);
AddTestTensor("input", {2, 2}, /*batch_size=*/1);
AddTestWeights<float>("weights", {1, 2, 2}, {1, 2, 3, 4});
auto insert_ones = [](std::vector<int> v, int n) {
std::vector<int> ones(n, 1);
ones.insert(ones.end(), v.begin(), v.end());
return ones;
};
RunValidationAndConversion(node_def);
TRT_TensorOrWeights output;
TF_EXPECT_OK(GetTensorOrWeights("my_matmul", &output));
ASSERT_TRUE(output.is_tensor());
ExpectTrtDimsEqualsArray({2, 2}, output.tensor()->getDimensions());
const DataVec input_data{{"input", AsTensor<float>({0, 1, 2, 3})}};
DataVec output_data{{"my_matmul", ConstructTensor<float>(4)}};
TF_EXPECT_OK(BuildAndRun(input_data, &output_data));
if (!transpose_a && !transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(3, 4, 11, 16));
} else if (transpose_a && transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(4, 8, 7, 15));
} else if (transpose_a) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(6, 8, 10, 14));
} else if (transpose_b) {
EXPECT_THAT(GetSpanForData<float>(output_data[0]),
ElementsAre(2, 4, 8, 18));
}
}
}
// Add a leading 1 dimension to A, B and result.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[](MatMulTestParams p) {
p.shape_a.insert(p.shape_a.begin(), 1);
p.shape_b.insert(p.shape_b.begin(), 1);
p.expected_shape.insert(p.expected_shape.begin(), 1);
return p;
});
TestMatMulHelper(this, get_batch_matmul_nodedef, "BatchMatMul");
// Test with N > 1: weights cannot be batched in implicit batch mode.
// clang-format off
params.push_back(
MatMulTestParams{{2, 2, 2}, {0, 1, 2, 3, 0, 1, 2, 3}, false, // A
{2, 2, 2}, {0, 1, 2, 3, 0, 1, 2, 3}, false, // B
{2, 2, 2}, {2, 3, 6, 11, 2, 3, 6, 11}} // result
);
params.push_back(
MatMulTestParams{{2, 2, 3}, {0, 1, 2, 3, 4, 5, 0, 1, 2, 3, 4, 5},
false,
{2, 2, 3}, {1, 2, 3, 4, 5, 6, 1, 2, 3, 4, 5, 6}, true,
{2, 2, 2}, {8, 17, 26, 62, 8, 17, 26, 62}});
// clang-format on
// Add two leading 1 dimensions to A, B and result.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a = insert_ones(p.shape_a, 2);
p.shape_b = insert_ones(p.shape_b, 2);
p.expected_shape = insert_ones(p.expected_shape, 2);
return p;
});
// Test broadcast: add two leading 1 dimensions to A, but not to B.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a = insert_ones(p.shape_a, 2);
p.expected_shape = insert_ones(p.expected_shape, 2);
return p;
});
// Test broadcast: add a leading 1 dimension to A and two leading 1s to B.
// Broadcasting A needs a dynamic broadcast, which is incompatible with the
// FC layer.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a = insert_ones(p.shape_a, 1);
p.shape_b = insert_ones(p.shape_b, 2);
p.expected_shape = insert_ones(p.expected_shape, 2);
return p;
});
// Test with N > 1: weights cannot be batched in implicit batch mode, so we
// test with batch size 2.
std::transform(params_2d.begin(), params_2d.end(), std::back_inserter(params),
[insert_ones](MatMulTestParams p) {
p.shape_a.insert(p.shape_a.begin(), 2);
p.values_a.reserve(p.values_a.size() * 2);
p.values_a.insert(p.values_a.end(), p.values_a.begin(),
p.values_a.end());
p.shape_b.insert(p.shape_b.begin(), 2);
p.values_b.reserve(p.values_b.size() * 2);
p.values_b.insert(p.values_b.end(), p.values_b.begin(),
p.values_b.end());
p.expected_shape.insert(p.expected_shape.begin(), 2);
p.expected_output.reserve(p.expected_output.size() * 2);
p.expected_output.insert(p.expected_output.end(),
p.expected_output.begin(),
p.expected_output.end());
return p;
});
TestMatMulHelper(this, get_batch_matmul_nodedef, params);
}
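For readers following the new parameterization above, here is a minimal standalone sketch (not part of the test; names are hypothetical) of what the insert_ones helper does to a 2-D MatMul shape when building the batched variants:

#include <iostream>
#include <vector>

// Same idea as the insert_ones lambda above: prepend `n` ones to a shape.
std::vector<int> InsertOnes(std::vector<int> v, int n) {
  std::vector<int> ones(n, 1);
  ones.insert(ones.end(), v.begin(), v.end());
  return ones;
}

int main() {
  const std::vector<int> shape_a = {2, 3};  // a 2-D MatMul operand shape
  for (int d : InsertOnes(shape_a, 2)) {    // batched shape becomes 1 1 2 3
    std::cout << d << " ";
  }
  std::cout << std::endl;
  return 0;
}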
TEST_P(OpConverter_FP32_FP16_Test, ConvertBiasAdd) {

View File

@ -247,7 +247,9 @@ Status SetTrtEngineOutputs(nvinfer1::ICudaEngine* cuda_engine,
bool status = output_tensor->CopyFrom(*output_tensor, output_shape);
if (!status) {
return errors::Internal(
"Buffer size do not match while reshaping output tensors");
"Buffer size (", output_tensor->NumElements(),
") does not match while reshaping output tensors to shape ",
output_shape.DebugString());
}
}

View File

@ -110,13 +110,14 @@ namespace {
// Log just once by default (on default log level), and let the user adjust
// the log level for more detailed logging.
void LogAtLeastOnce(const std::string& log_message) {
if (VLOG_IS_ON(1)) {
VLOG(1) << log_message;
} else {
LOG_FIRST_N(INFO, 1) << log_message;
#define LOG_AT_LEAST_ONCE(log_message) \
{ \
if (VLOG_IS_ON(1)) { \
VLOG(1) << log_message; \
} else { \
LOG_FIRST_N(INFO, 1) << log_message; \
} \
}
}
} // namespace
@ -132,18 +133,19 @@ Status MlirBridgePass::Run(const ConfigProto& config_proto,
// based on the devices in the module.
if (GetPassState(/*device_set=*/nullptr, config_proto, graph) ==
MlirOptimizationPassState::Disabled) {
LogAtLeastOnce("Skipping MLIR TPU Bridge, session flag not enabled");
LOG_AT_LEAST_ONCE("Skipping MLIR TPU Bridge, session flag not enabled");
mlir_bridge_gauge_v2->GetCell()->Set(false);
return Status::OK();
}
// Skip MLIR TPU Bridge if no TPU devices or TPU ops found.
if (!HasTPUDevicesAndOps(module)) {
LogAtLeastOnce("Skipping MLIR TPU Bridge, no TPU devices or TPU ops found");
LOG_AT_LEAST_ONCE(
"Skipping MLIR TPU Bridge, no TPU devices or TPU ops found");
return Status::OK();
}
LogAtLeastOnce("Running MLIR TPU Bridge");
LOG_AT_LEAST_ONCE("Running MLIR TPU Bridge");
mlir_bridge_gauge_v2->GetCell()->Set(true);
TF_RETURN_IF_ERROR(
@ -176,7 +178,7 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
// based on the devices in the module.
if (!IsEnabled(/*device_set=*/nullptr, options.session_options->config,
**options.graph)) {
LogAtLeastOnce(
LOG_AT_LEAST_ONCE(
"Skipping MLIR TPU Bridge V1 Compat, session flag not enabled");
mlir_bridge_gauge_v1->GetCell()->Set(false);
return Status::OK();
@ -184,12 +186,12 @@ Status MlirBridgeV1CompatPass::Run(const GraphOptimizationPassOptions& options,
// Skip MLIR TPU Bridge if no TPU devices or TPU ops found.
if (!HasTPUDevicesAndOps(module)) {
LogAtLeastOnce(
LOG_AT_LEAST_ONCE(
"Skipping MLIR TPU Bridge V1 Compat, no TPU devices or TPU ops found");
return Status::OK();
}
LogAtLeastOnce("Running MLIR TPU Bridge V1 Compat");
LOG_AT_LEAST_ONCE("Running MLIR TPU Bridge V1 Compat");
mlir_bridge_gauge_v1->GetCell()->Set(true);
TF_RETURN_IF_ERROR(

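A note on the macro form above: presumably LOG_FIRST_N keeps its static occurrence counter (and reports the source location) at each expansion site, so wrapping it in a helper function would make all call sites share one counter. A minimal standalone sketch of that per-call-site behavior (hypothetical macro, plain std::cout instead of the TensorFlow logging macros):

#include <atomic>
#include <iostream>

// Each expansion gets its own static counter, so "log once" applies
// independently at every call site.
#define LOG_ONCE_SKETCH(msg)                \
  do {                                      \
    static std::atomic<int> counter{0};     \
    if (counter.fetch_add(1) == 0) {        \
      std::cout << msg << std::endl;        \
    }                                       \
  } while (0)

int main() {
  for (int i = 0; i < 3; ++i) {
    LOG_ONCE_SKETCH("site A logs once");  // printed exactly once
    LOG_ONCE_SKETCH("site B logs once");  // printed exactly once, independently
  }
  return 0;
}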
View File

@ -277,9 +277,23 @@ def conv(lhs,
precision_config_proto = ""
if precision_config:
precision_config_proto = precision_config.SerializeToString()
needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
if preferred_element_type is None:
preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
return gen_xla_ops.xla_conv_v2(
if needs_v2:
return gen_xla_ops.xla_conv_v2(
lhs,
rhs,
window_strides=window_strides,
padding=padding,
lhs_dilation=lhs_dilation,
rhs_dilation=rhs_dilation,
feature_group_count=feature_group_count,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
return gen_xla_ops.xla_conv(
lhs,
rhs,
window_strides=window_strides,
@ -289,7 +303,6 @@ def conv(lhs,
feature_group_count=feature_group_count,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
@ -309,14 +322,22 @@ def dot_general(lhs,
precision_config_proto = ""
if precision_config:
precision_config_proto = precision_config.SerializeToString()
needs_v2 = preferred_element_type or (lhs.dtype != rhs.dtype)
if preferred_element_type is None:
preferred_element_type = np_utils.result_type(lhs.dtype, rhs.dtype)
return gen_xla_ops.xla_dot_v2(
if needs_v2:
return gen_xla_ops.xla_dot_v2(
lhs,
rhs,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)
return gen_xla_ops.xla_dot(
lhs,
rhs,
dimension_numbers=dimension_numbers.SerializeToString(),
precision_config=precision_config_proto,
preferred_element_type=preferred_element_type,
name=name)

View File

@ -804,9 +804,12 @@ Status XlaCompiler::CompileFunction(
}
VLOG(1) << "====================================================";
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(
*graph, config_proto,
/*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
MlirBridgeRolloutPolicy policy = MlirBridgeRolloutPolicy::kDisabledByUser;
if (options.is_entry_computation) {
policy = GetMlirBridgeRolloutPolicy(
*graph, config_proto,
/*uses_uninitialized_resource_args=*/AnyUninitializedResourceArg(args));
}
if (policy == MlirBridgeRolloutPolicy::kEnabledByUser) {
VLOG(1) << "Using MLIR bridge";
GraphDebugInfo debug_info;

View File

@ -89,10 +89,16 @@ class XlaOpKernelContext {
// xla::PRIMITIVE_TYPE_INVALID.
xla::PrimitiveType InputXlaType(absl::string_view name);
// Returns the shape of input `index`.
// Returns the shape of the input at `index` or of the input with the given
// `name`. Note that if the shape of the input is not static, the returned
// shape uses the bounds as dimension sizes instead of unknown dimensions. Use
// InputXlaShape instead, which provides shapes with dynamism information.
//
ABSL_DEPRECATED(
"Prefer InputXlaShape which handles dynamic shapes accurately.")
TensorShape InputShape(int index);
// Returns the shape of input with name `name`.
ABSL_DEPRECATED(
"Prefer InputXlaShape which handles dynamic shapes accurately.")
TensorShape InputShape(absl::string_view name);
// Returns input `index` as a XlaOp. Unlike

View File

@ -130,7 +130,7 @@ LocalExecutable::RunHelper(const absl::Span<const Shape* const> argument_shapes,
}
for (int i = 0, end = argument_shapes.size(); i < end; ++i) {
if (!computation_layout.parameter_layout(i).MatchesLayoutInShape(
*argument_shapes[i], /*minor_to_major_only=*/true)) {
*argument_shapes[i])) {
return InvalidParameterArgument(
executable_.get(), i,
"Argument does not match host shape or layout of computation "
@ -175,7 +175,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::Run(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ShapedBuffer* const arg : arguments) {
argument_shapes.push_back(&arg->on_device_shape());
argument_shapes.push_back(&arg->on_host_shape());
}
return AsyncCallAndBlockHostUntilDone<xla::ScopedShapedBuffer>(
argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
@ -188,7 +188,7 @@ StatusOr<ExecutionOutput> LocalExecutable::Run(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ExecutionInput& arg : arguments) {
argument_shapes.push_back(&arg.shape());
argument_shapes.push_back(&arg.host_shape());
}
return AsyncCallAndBlockHostUntilDone<ExecutionOutput>(
argument_shapes, run_options, [&](const ExecutableRunOptions& options) {
@ -243,7 +243,7 @@ StatusOr<ScopedShapedBuffer> LocalExecutable::RunAsync(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ShapedBuffer* const arg : arguments) {
argument_shapes.push_back(&arg->on_device_shape());
argument_shapes.push_back(&arg->on_host_shape());
}
TF_ASSIGN_OR_RETURN(auto options_and_stream,
RunHelper(argument_shapes, run_options));
@ -324,7 +324,7 @@ StatusOr<ExecutionOutput> LocalExecutable::RunAsync(
std::vector<const Shape*> argument_shapes;
argument_shapes.reserve(arguments.size());
for (const ExecutionInput& arg : arguments) {
argument_shapes.push_back(&arg.shape());
argument_shapes.push_back(&arg.host_shape());
}
return RunAsync(argument_shapes, std::move(arguments), run_options);
}

View File

@ -173,6 +173,12 @@ static void AllocateFlags() {
return true;
};
// Custom "sub-parser" lambda for xla_gpu_llvm_ir_file.
auto setter_for_xla_gpu_llvm_ir_file = [](const string& value) {
flag_values->add_xla_gpu_llvm_ir_file(value);
return true;
};
// Custom "sub-parser" lambda for xla_backend_extra_options.
auto setter_for_xla_backend_extra_options =
[](string comma_separated_values) {
@ -370,7 +376,15 @@ static void AllocateFlags() {
"If non-empty, specifies a file containing ptx to use. The filename "
"prefix must have the same pattern as PTX dumped by XLA. This allows to "
"match one specific module. General workflow. Get the generated module "
"ptx from XLA. Modify it. Then pass it back via this option."));
"ptx from XLA, modify it, then pass it back via this option."));
flag_objects->push_back(tensorflow::Flag(
"xla_gpu_llvm_ir_file", setter_for_xla_gpu_llvm_ir_file, "",
"If non-empty, specifies a file containing textual LLVM IR to use. The "
"filename prefix must have the same pattern as LLVM dumped by XLA "
"(i.e. module_0001.ir-no-opt.ll -> module_0001.MY_NEW_FILE.ll). This "
"allows matching one specific module. General workflow: get the unoptimized "
"LLVM IR from XLA, modify it, then pass it back via this option."));
flag_objects->push_back(tensorflow::Flag(
"xla_test_all_output_layouts",
bool_setter_for(&DebugOptions::set_xla_test_all_output_layouts),

View File

@ -1368,8 +1368,7 @@ array with the same shape. It is allowed for `operand` to be a scalar (rank 0).
The XLA FFT operation implements the forward and inverse Fourier Transforms for
real and complex inputs/outputs. Multidimensional FFTs on up to 3 axes are
supported, except on TPU, where only a single axis is supported (please file a
GitHub issue if you require higher order).
supported.
See also
[`XlaBuilder::Fft`](https://www.tensorflow.org/code/tensorflow/compiler/xla/client/xla_builder.h).

View File

@ -145,14 +145,14 @@ class LocalDeviceState {
// thread or on the worker thread (depending on thread schedules), not a
// device callback, so it is safe if the destructor frees device resource
// (e.g., GPU objects).
// TODO(phawkins): use move-capture when we can use C++14 features.
template <typename T>
void ThenRelease(se::Stream* stream, T object) const {
void ThenRelease(se::Stream* stream, T&& object) const {
if (callback_stream_.get() != stream) {
callback_stream_->ThenWaitFor(stream);
}
ThenExecuteOnCallbackThread(callback_stream_.get(),
[object]() { /* releases object */ });
ThenExecuteOnCallbackThread(
callback_stream_.get(),
[object = std::forward<T>(object)]() { /* releases object */ });
}
Semaphore& compute_semaphore() { return compute_semaphore_; }

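The ThenRelease change above switches to a C++14 generalized (move) capture so the object is owned by the closure until the callback runs, which also supports move-only types. A minimal standalone sketch of the pattern (hypothetical names, unrelated to StreamExecutor):

#include <iostream>
#include <memory>
#include <utility>

template <typename T>
auto MakeReleaseCallback(T&& object) {
  // Ownership moves into the closure; the object is released when the closure
  // is destroyed after it has run.
  return [object = std::forward<T>(object)]() { /* releases object */ };
}

int main() {
  auto buffer = std::make_unique<int>(42);
  auto release = MakeReleaseCallback(std::move(buffer));
  std::cout << (buffer == nullptr) << std::endl;  // 1: ownership moved into closure
  release();  // would run on the callback thread; destroying `release` frees it
  return 0;
}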
View File

@ -1340,7 +1340,7 @@ PjRtStreamExecutorBuffer::CopyToDeviceHelper(
// have completed.
tensorflow::down_cast<PjRtStreamExecutorDevice*>(device_)
->local_device_state()
->ThenRelease(transfer_stream, src_device_buffer);
->ThenRelease(transfer_stream, std::move(src_device_buffer));
}
return copy_event_or.status();
}

View File

@ -485,12 +485,13 @@ class CopyRemover {
};
CopyRemover(const HloModule& module, const HloAliasAnalysis& alias_analysis,
const HloOrdering& ordering)
const HloOrdering& ordering, bool check_live_range_ordering)
: dataflow_(alias_analysis.dataflow_analysis()), ordering_(ordering) {
// Construct a list for each HLO buffer in the alias analysis. Maintain a
// map from HloValue to the respective list element representing that
// value. The map is used to construct the copy info map below.
absl::flat_hash_map<const HloValue*, ValueNode*> value_to_node;
// Perform check only if the default dependence-based ordering is used.
for (const HloBuffer& buffer : alias_analysis.buffers()) {
// No copies should have been inserted within fused computations, so no
// need to remove them. HloOrdering isn't compatible with HloValues inside
@ -498,24 +499,26 @@ class CopyRemover {
if (buffer.values().at(0)->defining_instruction()->IsFused()) {
continue;
}
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
dataflow_) ||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
dataflow_))
<< value_a->ToShortString() << " and "
<< value_b->ToShortString() << " are not ordered";
if (check_live_range_ordering) {
// Verify values contained in the buffer are strictly ordered. This
// should always be the case after adding copies to eliminate
// interference. Specifically, the addition of the control flow edges
// between copies added around aliased operations (kWhile) guarantees
// this strict order.
for (const HloValue* value_a : buffer.values()) {
if (value_a->shape().IsToken()) {
// Token values have no representation and cannot interfere.
continue;
}
for (const HloValue* value_b : buffer.values()) {
if (value_a != value_b) {
DCHECK(ordering_.LiveRangeStrictlyBefore(*value_a, *value_b,
dataflow_) ||
ordering_.LiveRangeStrictlyBefore(*value_b, *value_a,
dataflow_))
<< value_a->ToString() << " and " << value_b->ToString()
<< " are not ordered";
}
}
}
}
@ -729,27 +732,31 @@ class CopyRemover {
VLOG(2) << copy->name() << " defines the first value in its buffer";
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
next_dest = Next(*next_dest)) {
// Live range of 'from' value (s_x) must be before 'next_dest' (d_1);
if (!LiveRangeBefore(*src, *next_dest)) {
VLOG(2) << "Not removing the copy: live range of "
<< src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
// Live range of (s_x, s_{x-1},...) must be before 'next_dest' (d_1);
for (ValueNode* prev_src = src; prev_src != nullptr;
prev_src = Prev(*prev_src)) {
if (!LiveRangeBefore(*prev_src, *next_dest)) {
VLOG(2) << "Not removing the copy: live range of "
<< prev_src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
}
}
}
for (ValueNode* next_src = Next(*src); next_src != nullptr;
next_src = Next(*next_src)) {
// Live range of 'last_dest' (d_m) must be before 'next_src' s_{x+1}.
ValueNode* last_dest = dest->prev;
DCHECK(IsTail(*last_dest));
if (!LiveRangeBefore(*last_dest, *next_src)) {
VLOG(2) << "Not removing the copy: live range of "
<< last_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
for (ValueNode* last_dest = dest->prev; last_dest != nullptr;
last_dest = Prev(*dest)) {
if (!LiveRangeBefore(*last_dest, *next_src)) {
VLOG(2) << "Not removing the copy: live range of "
<< last_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
}
}
}
VLOG(2) << "Splice dest after source.";
// Splice in destination buffer values list right after 'src'.
SpliceAfter(dest, src);
} else if (IsTail(*src)) {
@ -769,32 +776,36 @@ class CopyRemover {
VLOG(2) << copy->name() << " copies the last value ("
<< src->value->ToShortString() << ") in its buffer";
ValueNode* first_src = src->next;
DCHECK(IsHead(*first_src));
for (ValueNode* prev_dest = Prev(*dest);
// nullptr condition handled above in the first 'if' case.
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
if (!LiveRangeBefore(*prev_dest, *first_src)) {
// Live range of value d_{y-1} is not before s_0.
VLOG(2) << "Not removing the copy: live range of "
<< prev_dest->value->ToShortString() << " is not before "
<< first_src->value->ToShortString();
return false;
for (ValueNode* next_src = src->next; next_src != nullptr;
next_src = Next(*next_src)) {
for (ValueNode* prev_dest = Prev(*dest);
// nullptr condition handled above in the first 'if' case.
prev_dest != nullptr; prev_dest = Prev(*prev_dest)) {
if (!LiveRangeBefore(*prev_dest, *next_src)) {
// Live range of value d_{y-1} is not before s_0.
VLOG(2) << "Not removing the copy: live range of "
<< prev_dest->value->ToShortString() << " is not before "
<< next_src->value->ToShortString();
return false;
}
}
}
for (ValueNode* next_dest = Next(*dest); next_dest != nullptr;
next_dest = Next(*next_dest)) {
if (!LiveRangeBefore(*src, *next_dest)) {
// Live range of value s_n is not before d_{y+1}.
VLOG(2) << "Not removing the copy: live range of "
<< src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
for (ValueNode* prev_src = src; prev_src != nullptr;
prev_src = Prev(*prev_src)) {
if (!LiveRangeBefore(*prev_src, *next_dest)) {
// Live range of value s_n is not before d_{y+1}.
VLOG(2) << "Not removing the copy: live range of "
<< prev_src->value->ToShortString() << " is not before "
<< next_dest->value->ToShortString();
return false;
}
}
}
VLOG(2) << "Splice src after prev of dest.";
// Splice source buffer values list right after 'prev_dest'.
SpliceAfter(first_src, Prev(*dest));
SpliceAfter(src->next, Prev(*dest));
} else {
VLOG(2) << copy->name()
<< " copies value in middle of source buffer to value in middle "
@ -1175,12 +1186,14 @@ static int64 GetNumExistingCopies(const HloModule* module) {
}
Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module) {
HloModule* module,
bool check_live_range_ordering) {
XLA_VLOG_LINES(4, module->ToString());
TF_ASSIGN_OR_RETURN(std::unique_ptr<HloAliasAnalysis> alias_analysis,
HloAliasAnalysis::Run(module, can_share_buffer_));
CopyRemover copy_remover(*module, *alias_analysis, ordering);
CopyRemover copy_remover(*module, *alias_analysis, ordering,
check_live_range_ordering);
if (VLOG_IS_ON(3)) {
LOG(INFO) << "Removing unnecessary copies in " << module->name();
LOG(INFO) << "Buffer values, in dependency order: ";
@ -1200,7 +1213,9 @@ Status CopyInsertion::RemoveUnnecessaryCopies(const HloOrdering& ordering,
VLOG(2) << "Running fixpoint iteration " << num_iterations
<< " of copy elision";
for (HloComputation* computation : module->computations()) {
VLOG(2) << "computation:" << computation->name() << "\n";
for (HloInstruction* instruction : computation->instructions()) {
VLOG(2) << instruction->ToString() << "\n";
if (instruction->opcode() == HloOpcode::kCopy &&
copy_remover.TryElideCopy(instruction)) {
changed = true;
@ -1260,7 +1275,8 @@ StatusOr<bool> CopyInsertion::Run(HloModule* module) {
name(), "after adding copies to resolve interference", *module);
TF_RETURN_IF_ERROR(
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module));
RemoveUnnecessaryCopies(DependencyHloOrdering(module), module,
/*check_live_range_ordering=*/true));
DumpHloModuleDuringPassIfEnabled(name(), "after removing unnecessary copies",
*module);
TF_RETURN_IF_ERROR(AddSpecialCaseCopies(*call_graph, module));

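To make the strengthened elision check above easier to follow, here is a heavily simplified, self-contained sketch (hypothetical types; real HLO live ranges are not plain intervals) of the nested prev_src/next_dest loops: every value at or before the copy source in its buffer list must end before every value after the copy destination begins.

#include <iostream>

// Hypothetical simplification: each value's live range is an interval
// [start, end) on one timeline; "live range before" means `a` ends no later
// than `b` starts.
struct ValueNode {
  int start;
  int end;
  ValueNode* prev;
  ValueNode* next;
};

bool LiveRangeBefore(const ValueNode& a, const ValueNode& b) {
  return a.end <= b.start;
}

// Every value at or before `src` in the source buffer list must be
// live-before every value after `dest` in the destination list.
bool CanSpliceDestAfterSrc(ValueNode* src, ValueNode* dest) {
  for (ValueNode* next_dest = dest->next; next_dest != nullptr;
       next_dest = next_dest->next) {
    for (ValueNode* prev_src = src; prev_src != nullptr;
         prev_src = prev_src->prev) {
      if (!LiveRangeBefore(*prev_src, *next_dest)) return false;
    }
  }
  return true;
}

int main() {
  ValueNode s0{0, 2, nullptr, nullptr}, s1{2, 4, nullptr, nullptr};  // source buffer
  ValueNode d0{4, 5, nullptr, nullptr}, d1{5, 7, nullptr, nullptr};  // destination buffer
  s0.next = &s1; s1.prev = &s0;
  d0.next = &d1; d1.prev = &d0;
  std::cout << CanSpliceDestAfterSrc(&s1, &d0) << std::endl;  // 1: elision is safe
  return 0;
}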
View File

@ -63,8 +63,10 @@ class CopyInsertion : public HloModulePass {
// Try to remove as many copies from the module as possible without
// introducing live range interference. Only copy instructions that are
// eligible for copy elision are considered for removal.
Status RemoveUnnecessaryCopies(const HloOrdering& ordering,
HloModule* module);
// If check_live_range_ordering is true, check that live ranges are ordered
// in all the existing aliased buffers.
Status RemoveUnnecessaryCopies(const HloOrdering& ordering, HloModule* module,
bool check_live_range_ordering = false);
// Add copies to address special constraints on the roots of computations not
// related to live range interference:

View File

@ -2877,6 +2877,56 @@ ENTRY main {
EXPECT_EQ(CountCopies(*module), 1);
}
TEST_F(CopyInsertionTest, HorizontalLoopFusionNoCopy) {
const string& hlo_string = R"(
HloModule test
fused_computation {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
add0 = f32[10, 20] add(p0, p1)
sub0 = f32[10, 10] subtract(p2, p3)
reshape0 = f32[200] reshape(add0)
reshape1 = f32[100] reshape(sub0)
concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
slice0 = f32[200] slice(concat0), slice={[0:200]}
slice1 = f32[100] slice(concat0), slice={[200:300]}
ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}
ENTRY test {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
gte0 = f32[200] get-tuple-element(fusion), index=0
gte1 = f32[100] get-tuple-element(fusion), index=1
bitcast0 = f32[10,20] bitcast(gte0)
bitcast1 = f32[10,10] bitcast(gte1)
ROOT tuple = (f32[10,20], f32[10,10]) tuple(bitcast0, bitcast1)
}
)";
TF_ASSERT_OK_AND_ASSIGN(std::unique_ptr<HloModule> module,
ParseAndReturnVerifiedModule(hlo_string));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{0},
/*param_number=*/0,
/*param_index=*/{}));
ASSERT_IS_OK(module->input_output_alias_config().SetUpAlias(
/*output_index=*/{1},
/*param_number=*/3,
/*param_index=*/{}));
InsertCopies(module.get());
// There should be no copies inserted.
EXPECT_EQ(CountCopies(*module), 0);
}
TEST_F(CopyInsertionTest, NestedWhileAndConditional3) {
const string& hlo_string = R"(
HloModule TestModule

View File

@ -1405,6 +1405,8 @@ cc_library(
"//tensorflow/stream_executor:stream_executor_headers",
"//tensorflow/stream_executor/cuda:cuda_diagnostics",
"//tensorflow/stream_executor/gpu:asm_compiler",
"@llvm-project//llvm:IRReader",
"@llvm-project//llvm:Support",
]) + ["//tensorflow/stream_executor/gpu:gpu_driver_header"],
)

View File

@ -174,14 +174,6 @@ bool IsProfitableFusionCandidate(const HloInstruction& instr) {
return false;
}
// We can emit DUS in-place, horizontally fusing it makes the emitter no
// longer recognize that it can be done in-place. This creates much slower
// code. This restriction could be lifted if buffer assignment would recognize
// that the DUS can be done in-place even inside of a horizontal fusion.
if (root->opcode() == HloOpcode::kDynamicUpdateSlice) {
return false;
}
return true;
}
@ -203,6 +195,19 @@ bool HasOnlyRowMajorLayout(const HloInstruction& fusion_instr) {
return true;
}
// Returns whether any operand of `instr` is a parameter instruction that
// is shared with `fusion_instrs`.
bool AnyOpndIsParamSharedAmongFusions(
const HloInstruction* instr,
const absl::flat_hash_set<HloInstruction*>& fusion_instrs) {
return absl::c_any_of(instr->operands(), [&](const HloInstruction* opnd) {
return opnd->opcode() == HloOpcode::kParameter &&
absl::c_any_of(opnd->users(), [&](const HloInstruction* user) {
return user != instr && fusion_instrs.contains(user);
});
});
}
void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
HloInstruction* consumer) {
// First, find out all fusion instructions. We will filter out
@ -230,6 +235,14 @@ void HorizontalLoopFusionImpl::FusionCandidates::Initialize(
} else if (!HasOnlyRowMajorLayout(*instr)) {
VLOG(2) << "Reject non-row-major fusion instr " << instr->ToString();
continue;
} else if (AnyOpndIsParamSharedAmongFusions(instr, fusion_instrs)) {
// Don't fuse fusions whose operands are parameter instructions that are
// shared among fusions because we cannot i/o alias the produced
// horizontal fusion due to the concat insertion.
VLOG(2) << "Reject the fusion instr because it shares a parameter with"
<< " other fusion candidates, instr: " << instr->ToString();
continue;
} else {
VLOG(2) << "Find a fusion candidate " << instr->ToString();
fusion_instrs_.push_back(instr);

View File

@ -364,33 +364,33 @@ TEST_F(HorizontalLoopFusionTest, RMSPropLike) {
EXPECT_TRUE(RunAndCompare(std::move(module), ErrorSpec{1.0e-5, 1.0e-5}));
}
TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
TEST_F(HorizontalLoopFusionTest, DynamicUpdateSlice) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule NegativeTestForDynamicUpdateSlice
fusion.1 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}
fusion.2 {
p.0 = f16[5,9,10]{2,1,0} parameter(0)
p.1 = s32[1]{0} parameter(1)
p.1 = s32[] parameter(1)
p.2 = f16[1,9,10]{2,1,0} parameter(2)
c.0 = s32[] constant(0)
pad = s32[3]{0} pad(p.1, c.0), padding=0_2
ROOT %dynamic-update-slice = f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, pad)
ROOT %dynamic-update-slice =
f16[5,9,10]{2,1,0} dynamic-update-slice(p.0, p.2, p.1, c.0, c.0)
}
ENTRY entry {
p.00 = f16[5,9,10]{2,1,0} parameter(0)
p.01 = f16[5,9,10]{2,1,0} parameter(1)
p.10 = s32[1]{0} parameter(2)
p.11 = s32[1]{0} parameter(3)
p.10 = s32[] parameter(2)
p.11 = s32[] parameter(3)
p.20 = f16[1,9,10]{2,1,0} parameter(4)
p.21 = f16[1,9,10]{2,1,0} parameter(5)
@ -400,6 +400,46 @@ TEST_F(HorizontalLoopFusionTest, NegativeTestForDynamicUpdateSlice) {
})")
.ValueOrDie();
EXPECT_TRUE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
EXPECT_TRUE(HloDCE().Run(module.get()).ValueOrDie());
VLOG(2) << "Dump after horizontal fusion:";
VLOG(2) << module->ToString();
EXPECT_TRUE(RunAndCompareNoHloPasses(std::move(module), ErrorSpec{0, 0}));
}
TEST_F(HorizontalLoopFusionTest, NegativeTestForSharedParam) {
auto module = ParseAndReturnVerifiedModule(R"(
HloModule BasicTest
fused_computation.1 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT mul.1 = f16[123]{0} multiply(arg.1, arg.2)
}
fused_computation.2 {
arg.1 = f16[123]{0} parameter(0)
arg.2 = f16[123]{0} parameter(1)
ROOT add.1 = f16[123]{0} add(arg.1, arg.2)
}
ENTRY entry_computation {
arg.1 = f16[123]{0} parameter(0)
// arg.2 is shared by fusion.1 and fusion.2
arg.2 = f16[123]{0} parameter(1)
arg.3 = f16[123]{0} parameter(2)
fusion.1 = f16[123]{0}
fusion(arg.1, arg.2), kind=kLoop, calls=fused_computation.1
fusion.2 = f16[123]{0}
fusion(arg.3, arg.2), kind=kLoop, calls=fused_computation.2
ROOT tuple.1 = (f16[123]{0}, f16[123]{0})
tuple(fusion.1, fusion.2)
}
)")
.ValueOrDie();
EXPECT_FALSE(GpuHorizontalLoopFusion().Run(module.get()).ValueOrDie());
}

View File

@ -299,6 +299,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
(f64_atomic_add_supported && element_type == F64);
if (atomic_add_supported) {
AtomicRMW(llvm::AtomicRMWInst::FAdd, output_address, source,
llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -307,6 +308,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
if (is_atomic_integral) {
// integral + integral
AtomicRMW(llvm::AtomicRMWInst::Add, output_address, source,
llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -318,7 +320,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Max
: llvm::AtomicRMWInst::UMax;
AtomicRMW(opcode, output_address, source,
AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -328,7 +330,7 @@ bool IrEmitter::MaybeEmitDirectAtomicOperation(
auto opcode = primitive_util::IsSignedIntegralType(element_type)
? llvm::AtomicRMWInst::Min
: llvm::AtomicRMWInst::UMin;
AtomicRMW(opcode, output_address, source,
AtomicRMW(opcode, output_address, source, llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent);
return true;
}
@ -476,10 +478,10 @@ Status IrEmitter::EmitAtomicOperationUsingCAS(const HloComputation& computation,
// Emit code to perform the atomicCAS operation
// (cas_old_output, success) = atomicCAS(memory_address, cas_old_output,
// cas_new_output);
llvm::Value* ret_value =
AtomicCmpXchg(atomic_memory_address, cas_old_output, cas_new_output,
llvm::AtomicOrdering::SequentiallyConsistent,
llvm::AtomicOrdering::SequentiallyConsistent);
llvm::Value* ret_value = AtomicCmpXchg(
atomic_memory_address, cas_old_output, cas_new_output, llvm::MaybeAlign(),
llvm::AtomicOrdering::SequentiallyConsistent,
llvm::AtomicOrdering::SequentiallyConsistent);
// Extract the memory value returned from atomicCAS and store it as
// cas_old_output.

View File

@ -20,6 +20,8 @@ limitations under the License.
#include <fstream>
#include "absl/base/call_once.h"
#include "llvm/IRReader/IRReader.h"
#include "llvm/Support/SourceMgr.h"
#include "tensorflow/compiler/xla/service/algebraic_simplifier.h"
#include "tensorflow/compiler/xla/service/dump.h"
#include "tensorflow/compiler/xla/service/gpu/cublas_gemm_pad_for_tensor_cores.h"
@ -202,7 +204,7 @@ absl::optional<bool> CanShareBufferHint(const HloInstruction* user,
// Try to load ptx from files defined in the FLAGS. If successful, return true.
bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
const HloModule* module, std::string* ptx) {
// If the xla_gpu_ptx_file options is set, be explicit when a file is used
// If the xla_gpu_ptx_file option is set, be explicit if a file is used
// and warn when a file is not used to ease catching typo in filename.
std::string prefix = xla::FilenameFor(*module, "", *ptx);
std::string matched_filename;
@ -234,6 +236,50 @@ bool MaybeLoadPtxFromFile(const HloModuleConfig module_config,
return false;
}
// Try to load textual LLVM IR from files defined in the FLAGS. If
// successful, return the llvm::Module, otherwise return nullptr.
std::unique_ptr<llvm::Module> MaybeLoadLLVMFromFile(const HloModule* module,
llvm::Module* llvm_module) {
// If the xla_gpu_llvm_ir_file option is set, be explicit if a file is used
// and warn when a file is not used, to ease catching typos in the filename.
if (module == nullptr) {
return nullptr;
}
std::string prefix = xla::FilenameFor(*module, "", "");
auto xla_gpu_llvm_ir_file =
module->config().debug_options().xla_gpu_llvm_ir_file();
auto matched_filename = absl::c_find_if(
xla_gpu_llvm_ir_file, [prefix](const string& full_filename) {
// To ease comparing many LLVM versions, accept different suffixes than
// the original filename.
return absl::StartsWith(tensorflow::io::Basename(full_filename),
prefix);
});
if (!xla_gpu_llvm_ir_file.empty() &&
matched_filename == std::end(xla_gpu_llvm_ir_file)) {
VLOG(0) << "RunBackend() - For module with prefix '" << prefix
<< "', we did not find an LLVM file to load.";
}
if (matched_filename != std::end(xla_gpu_llvm_ir_file)) {
VLOG(0) << "RunBackend() - Will load LLVM from file: " << *matched_filename;
llvm::LLVMContext& context = llvm_module->getContext();
llvm::SMDiagnostic err;
std::unique_ptr<llvm::Module> loaded_module =
llvm::parseIRFile(*matched_filename, err, context);
if (!loaded_module) {
err.print("ERR", llvm::errs());
LOG(FATAL) << "Failed to load an LLVM file. It is probably invalid LLVM.";
}
// Overwrite the dumped unoptimized LLVM IR to show which one will be used.
llvm_ir::DumpIrIfEnabled(*module, *loaded_module, /*optimized=*/false);
return loaded_module;
}
return nullptr;
}
} // namespace
// Prints a warning if the ptx->sass JIT in the driver has known bugs.
@ -320,13 +366,21 @@ NVPTXCompiler::CompileTargetBinary(const HloModuleConfig& module_config,
libdevice_dir = cached_libdevice_dir_;
}
VLOG(2) << "Libdevice dir = " << libdevice_dir << "\n";
std::unique_ptr<llvm::Module> loaded_module =
MaybeLoadLLVMFromFile(debug_module, llvm_module);
llvm::Module* selected_module = nullptr;
if (loaded_module) {
selected_module = loaded_module.get();
} else {
selected_module = llvm_module;
}
string ptx;
if (!(debug_module &&
MaybeLoadPtxFromFile(module_config, debug_module, &ptx))) {
XLA_SCOPED_LOGGING_TIMER(
"NVPTXCompiler::CompileTargetBinary - CompileToPtx");
TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(llvm_module, gpu_version,
TF_ASSIGN_OR_RETURN(ptx, nvptx::CompileToPtx(selected_module, gpu_version,
module_config, libdevice_dir));
}

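The matching above keys only on the dump prefix of the basename, so module_0001.ir-no-opt.ll can be edited and saved as module_0001.MY_NEW_FILE.ll and still be picked up. A minimal standalone sketch of that starts-with rule (hypothetical paths):

#include <algorithm>
#include <iostream>
#include <string>
#include <vector>

std::string Basename(const std::string& path) {
  const size_t pos = path.find_last_of('/');
  return pos == std::string::npos ? path : path.substr(pos + 1);
}

int main() {
  const std::string prefix = "module_0001";
  const std::vector<std::string> files = {"/tmp/module_0002.ll",
                                          "/tmp/module_0001.MY_NEW_FILE.ll"};
  auto it = std::find_if(files.begin(), files.end(), [&](const std::string& f) {
    return Basename(f).rfind(prefix, 0) == 0;  // basename starts with prefix
  });
  std::cout << (it != files.end() ? *it : "no match") << std::endl;
  return 0;
}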
View File

@ -120,6 +120,175 @@ bool HloDataflowAnalysis::AreTransitiveUsesElementwiseOrTuple(
return true;
}
namespace {
bool Is1dSliceWithoutStrides(const HloInstruction* instr) {
return instr->opcode() == HloOpcode::kSlice &&
1 == instr->slice_starts().size() &&
1 == instr->slice_limits().size() &&
1 == instr->slice_strides().size() &&
1 == instr->slice_strides().at(0);
}
bool IsSliceInputFusion(const HloInstruction& unnested_hlo) {
if (!unnested_hlo.IsInputFusion()) {
return false;
}
const HloInstruction* root = unnested_hlo.fused_expression_root();
if (root->opcode() != HloOpcode::kTuple) {
return false;
}
return absl::c_all_of(root->operands(), [](const HloInstruction* instr) {
return Is1dSliceWithoutStrides(instr);
});
}
struct ConcatUsageInfo {
// Pointer to a previously seen concat. nullptr if no previously seen concat.
const HloInstruction* prev_concat;
// The opnd id of the seen concat.
int64 concat_opnd_idx;
// The slice that recovers the opnd in the concat outputs.
const HloInstruction* slice_to_recover_opnd;
};
// Returns an optional concat usage info to denote whether the concat is used in
// an elementwise manner. A concat followed by slices is considered effectively
// elementwise if the slices, taken together, invert the concat.
absl::optional<ConcatUsageInfo> ConcatIsEffectivelyElementwise(
const HloInstruction& concat, const HloInstruction& operand,
const ConcatUsageInfo& info) {
// First, check whether this concat matches the pattern below. Also, check
// that the slices, taken together, are in effect the inverse of the concat.
//
// Concat
// | |
// v v
// Slice Slice
//
std::vector<HloInstruction*> users = concat.users();
if (!absl::c_all_of(users, Is1dSliceWithoutStrides)) {
// Limit our supported cases to 1 dimensional slices.
return absl::optional<ConcatUsageInfo>();
}
// Verify that each operand to the concat is reversed by a slice.
if (users.size() != concat.operand_count() ||
concat.operand_count() != concat.unique_operands().size()) {
return absl::optional<ConcatUsageInfo>();
}
absl::c_sort(users, [](const HloInstruction* a, const HloInstruction* b) {
return a->slice_starts().at(0) < b->slice_starts().at(0);
});
int64 prev_limit = 0;
for (int64 i = 0; i < users.size(); ++i) {
const HloInstruction* u = users[i];
int64 slice_size = u->slice_limits().at(0) - u->slice_starts().at(0);
if (u->slice_starts().at(0) != prev_limit ||
slice_size != ShapeUtil::ElementsIn(concat.operand(i)->shape())) {
return absl::optional<ConcatUsageInfo>();
}
prev_limit = u->slice_limits().at(0);
}
// If we have seen other concats, make sure they are identical. Multiple
// concats exist because horizontal fusion inserts one concat for each output
// of the fusion candidates. Check that all concats and operand ids are the
// same to know that the "transitive use closure" will be computed in the same
// iteration space.
int64 operand_idx = concat.operand_index(&operand);
if (info.prev_concat != nullptr) {
bool is_concat_identical = info.prev_concat->Identical(
concat,
/*eq_operands=*/[](const HloInstruction*, const HloInstruction*) {
// Operands don't need to be the same.
return true;
});
if (!is_concat_identical || info.concat_opnd_idx != operand_idx) {
return absl::optional<ConcatUsageInfo>();
}
}
const HloInstruction* slice_to_recover_opnd = users.at(operand_idx);
return absl::optional<ConcatUsageInfo>(
ConcatUsageInfo{&concat, operand_idx, slice_to_recover_opnd});
}
// Returns whether we can prove the transitive uses of `param` are in effect
// elementwise. In other words, we prove that the "transitive use closure" will
// all be computed in the same iteration space without any reorder of elements.
// In addition, we check that the "transitive use closure" includes the output
// in the `root_tuple`.
// Theoretically, we can prove more patterns, but our primary use case is
// SliceInputFusion.
bool AreTransitiveUsesEffectivelyElementwise(const HloInstruction* param,
const HloInstruction* root_tuple,
const ShapeIndex& out_shape_idx) {
CHECK_EQ(root_tuple->opcode(), HloOpcode::kTuple);
CHECK_EQ(out_shape_idx.size(), 1);
absl::flat_hash_set<const HloInstruction*> visited;
absl::InlinedVector<const HloInstruction*, 4> stack;
stack.push_back(param);
ConcatUsageInfo concat_usage_info{nullptr, 0, nullptr};
bool is_output_reachable = false;
while (!stack.empty()) {
const HloInstruction* current = stack.back();
stack.pop_back();
visited.insert(current);
for (const HloInstruction* user : current->users()) {
VLOG(3) << "Visiting: " << user->ToString();
switch (user->opcode()) {
case HloOpcode::kTuple:
if (user == root_tuple &&
current == root_tuple->operand(out_shape_idx.back())) {
// We need to know if the output is reachable by the `param` to make
// sure that they will be computed in the same iteration space.
is_output_reachable = true;
}
break;
case HloOpcode::kReshape:
if (!ShapeUtil::ReshapeIsBitcast(current->shape(), user->shape())) {
return false;
}
break;
case HloOpcode::kConcatenate: {
absl::optional<ConcatUsageInfo> optional_concat_info =
ConcatIsEffectivelyElementwise(*user, *current,
concat_usage_info);
if (!optional_concat_info) {
return false;
}
concat_usage_info = *optional_concat_info;
// Early continue as we only want to traverse through the slice that
// recovers the operand. It is guaranteed that the operand to the
// concat and the slice have the same iteration space. Insert the
// slice instead of the concat.
CHECK(!visited.contains(concat_usage_info.slice_to_recover_opnd));
stack.push_back(concat_usage_info.slice_to_recover_opnd);
continue;
}
default:
for (const int64 use_index : user->OperandIndices(current)) {
if (!user->IsElementwiseOnOperand(use_index)) {
// Found a user that is non-elementwise on the current
// instruction.
return false;
}
}
if (!LayoutUtil::Equal(current->shape().layout(),
user->shape().layout())) {
// Make sure the layout is not changed by the elementwise op.
return false;
}
break;
} // end of switch
if (!visited.contains(user)) {
stack.push_back(user);
}
}
}
return is_output_reachable;
}
} // namespace
bool HloDataflowAnalysis::ValueIsDefinedAt(const HloInstruction* instruction,
const ShapeIndex& index) const {
const HloValueSet& value_set = GetValueSet(instruction, index);
@ -1266,10 +1435,23 @@ bool HloDataflowAnalysis::CanShareOperandBufferWithUser(
if (operand->opcode() == HloOpcode::kConstant) {
return false;
}
const Shape& operand_subshape =
ShapeUtil::GetSubshape(operand->shape(), operand_index);
const Shape& user_subshape =
ShapeUtil::GetSubshape(user->shape(), user_index);
if (IsSliceInputFusion(*user)) {
HloInstruction* fusion_param =
user->fused_parameter(user->operand_index(operand));
// We don't require the same dimensions, only the same number of elements and
// the same element type (so that the buffer sizes match).
return operand_subshape.IsArray() && user_subshape.IsArray() &&
ShapeUtil::ElementsIn(operand_subshape) ==
ShapeUtil::ElementsIn(user_subshape) &&
ShapeUtil::SameElementType(operand_subshape, user_subshape) &&
AreTransitiveUsesEffectivelyElementwise(
fusion_param, user->fused_expression_root(), user_index);
}
// Check that operand and user emit the same shape and layout.
if (!ShapeUtil::Equal(operand_subshape, user_subshape)) {

View File
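A compact way to see the "slices invert the concat" condition used above: sorted by start offset, the 1-D slices must exactly tile the concat output, and the i-th slice must cover the same number of elements as the i-th concat operand. A self-contained sketch of just that arithmetic (hypothetical types; the HLO graph walk and uniqueness checks are omitted):

#include <algorithm>
#include <iostream>
#include <vector>

struct Slice1D {
  long start;
  long limit;
};

bool SlicesInvertConcat(std::vector<Slice1D> slices,
                        const std::vector<long>& operand_sizes) {
  if (slices.size() != operand_sizes.size()) return false;
  std::sort(slices.begin(), slices.end(),
            [](const Slice1D& a, const Slice1D& b) { return a.start < b.start; });
  long prev_limit = 0;
  for (size_t i = 0; i < slices.size(); ++i) {
    const long size = slices[i].limit - slices[i].start;
    if (slices[i].start != prev_limit || size != operand_sizes[i]) return false;
    prev_limit = slices[i].limit;
  }
  return true;
}

int main() {
  // concat(f32[200], f32[100]) sliced back into [0:200] and [200:300].
  std::cout << SlicesInvertConcat({{200, 300}, {0, 200}}, {200, 100});  // 1
  std::cout << SlicesInvertConcat({{0, 150}, {150, 300}}, {200, 100});  // 0
  std::cout << std::endl;
  return 0;
}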

@ -2795,5 +2795,150 @@ TEST_F(CanShareOperandBufferWithUserTest, CallToComputationWithFusionRoot) {
dataflow_analysis_->CanShareOperandBufferWithUser(reverse, {}, call, {}));
}
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceWithElementwise) {
const char* kModule = R"(
HloModule test
fused_computation {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
add0 = f32[10, 20] add(p0, p1)
sub0 = f32[10, 10] subtract(p2, p3)
reshape0 = f32[200] reshape(add0)
reshape1 = f32[100] reshape(sub0)
concat0 = f32[300] concatenate(reshape0, reshape1), dimensions={0}
slice0 = f32[200] slice(concat0), slice={[0:200]}
slice1 = f32[100] slice(concat0), slice={[200:300]}
ROOT tuple = (f32[200], f32[100]) tuple(slice0, slice1)
}
ENTRY test {
p0 = f32[10,20] parameter(0)
p1 = f32[10,20] parameter(1)
p2 = f32[10,10] parameter(2)
p3 = f32[10,10] parameter(3)
ROOT fusion = (f32[200], f32[100]) fusion(p0, p1, p2, p3), kind=kInput, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
auto* fusion = module_->entry_computation()->root_instruction();
auto* param0 = module_->entry_computation()->parameter_instruction(0);
auto* param1 = module_->entry_computation()->parameter_instruction(1);
auto* param2 = module_->entry_computation()->parameter_instruction(2);
auto* param3 = module_->entry_computation()->parameter_instruction(3);
RunAnalysis();
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {0}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {0}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param2, {},
fusion, {1}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param3, {},
fusion, {1}));
// Tensors of different sizes cannot share a buffer.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {1}));
}
TEST_F(CanShareOperandBufferWithUserTest, ConcatSliceNegativeTest) {
const char* kModule = R"(
HloModule test
fused_computation {
// p0 has multiple transitive uses fed to the concat, so p0 cannot share a
// buffer with the outputs: the aliased output could be written before all
// the uses of p0 are finished.
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
add0 = f32[100] add(p0, p1)
concat0 = f32[200] concatenate(p0, add0), dimensions={0}
slice0 = f32[100] slice(concat0), slice={[0:100]}
slice1 = f32[100] slice(concat0), slice={[100:200]}
ROOT tuple = (f32[100], f32[100]) tuple(slice0, slice1)
}
ENTRY test {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
ROOT fusion = (f32[100], f32[100]) fusion(p0, p1),
kind=kInput, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
auto* fusion = module_->entry_computation()->root_instruction();
auto* param0 = module_->entry_computation()->parameter_instruction(0);
auto* param1 = module_->entry_computation()->parameter_instruction(1);
RunAnalysis();
// p0 cannot share with either fusion{0} or fusion{1}.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {0}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {1}));
// p1 cannot share with fusion{0} because we're not sure about their
// relationship.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {0}));
// p1 can share with fusion{1} because they will be executed in an
// elementwise manner.
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {1}));
}
TEST_F(CanShareOperandBufferWithUserTest, MultipleConcatenates) {
const char* kModule = R"(
HloModule test
fused_computation {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
add0 = f32[100] add(p0, p1)
sub0 = f32[100] subtract(p1, p1)
concat0 = f32[200] concatenate(p0, add0), dimensions={0}
slice0 = f32[100] slice(concat0), slice={[0:100]}
slice1 = f32[100] slice(concat0), slice={[100:200]}
concat1 = f32[200] concatenate(p0, sub0), dimensions={0}
slice2 = f32[100] slice(concat1), slice={[0:100]}
slice3 = f32[100] slice(concat1), slice={[100:200]}
ROOT tuple = (f32[100], f32[100], f32[100], f32[100])
tuple(slice0, slice1, slice2, slice3)
}
ENTRY test {
p0 = f32[100] parameter(0)
p1 = f32[100] parameter(1)
ROOT fusion = (f32[100], f32[100], f32[100], f32[100])
fusion(p0, p1), kind=kInput, calls=fused_computation
}
)";
TF_ASSERT_OK_AND_ASSIGN(module_, ParseAndReturnVerifiedModule(kModule));
auto* fusion = module_->entry_computation()->root_instruction();
auto* param0 = module_->entry_computation()->parameter_instruction(0);
auto* param1 = module_->entry_computation()->parameter_instruction(1);
RunAnalysis();
// p0 cannot share.
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {0}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {1}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {2}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param0, {},
fusion, {3}));
// p1 can share with either fusion{1} or fusion{3}.
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {1}));
EXPECT_TRUE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {3}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {0}));
EXPECT_FALSE(dataflow_analysis_->CanShareOperandBufferWithUser(param1, {},
fusion, {2}));
}
} // namespace
} // namespace xla

View File

@ -1933,36 +1933,19 @@ Status LayoutAssignment::RunOnComputation(
// Copy the root instruction's result if its layout does not match the result
// layout constraint.
if (constraints.ResultLayout() != nullptr) {
// Layout assignment at this point only does minor-to-major assignment so
// tiling info should be ignored here for comparison.
if (!constraints.ResultLayout()->MatchesLayoutInShape(
computation->root_instruction()->shape(),
/*minor_to_major_only=*/true)) {
if (conditional_mismatch_.count(computation) > 0) {
*FindOrDie(computation_layouts_, computation).mutable_result_layout() =
FindOrDie(conditional_mismatch_, computation).result_layout();
}
TF_ASSIGN_OR_RETURN(
HloInstruction * new_root,
CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
computation->root_instruction()));
computation->set_root_instruction(new_root);
} else {
// Copy the specified tiling info.
auto assign_tiling = [&constraints](xla::Shape* subshape,
const xla::ShapeIndex& index) {
if (subshape->IsArray()) {
const Shape& result_shape = ShapeUtil::GetSubshape(
constraints.ResultLayout()->shape(), index);
subshape->mutable_layout()->mutable_tiles()->assign(
result_shape.layout().tiles().begin(),
result_shape.layout().tiles().end());
}
};
xla::ShapeUtil::ForEachMutableSubshape(
computation->root_instruction()->mutable_shape(), assign_tiling);
if (constraints.ResultLayout() != nullptr &&
!constraints.ResultLayout()->MatchesLayoutInShape(
computation->root_instruction()->shape(),
/*minor_to_major_only=*/true)) {
if (conditional_mismatch_.count(computation) > 0) {
*FindOrDie(computation_layouts_, computation).mutable_result_layout() =
FindOrDie(conditional_mismatch_, computation).result_layout();
}
TF_ASSIGN_OR_RETURN(
HloInstruction * new_root,
CreateCopyWithNewLayout(constraints.ResultLayout()->shape(),
computation->root_instruction()));
computation->set_root_instruction(new_root);
}
return Status::OK();
}

View File

@ -441,8 +441,7 @@ absl::flat_hash_map<int, HloInstruction*> CreateSinkedAllReduces(
}
CHECK(ContainsKey(tuple_index_to_old_buffer, tuple_index));
HloInstruction* old_buffer = tuple_index_to_old_buffer.at(tuple_index);
CHECK(Shape::Equal().IgnoreLayout()(old_buffer->shape(),
all_reduced_delta->shape()));
CHECK(ShapeUtil::Equal(old_buffer->shape(), all_reduced_delta->shape()));
HloInstruction* add_to_old_buffer =
while_parent->AddInstruction(HloInstruction::CreateBinary(
all_reduced_delta->shape(), HloOpcode::kAdd, old_buffer,

View File

@ -618,8 +618,8 @@ optional<int64> ComputeWhileLoopTripCountUpperBound(HloInstruction* while_op) {
}
Literal cond_result_pred = std::move(eval_result.ValueOrDie());
CHECK(Shape::Equal().IgnoreLayout()(cond_result_pred.shape(),
ShapeUtil::MakeShape(PRED, {})));
CHECK(ShapeUtil::Equal(cond_result_pred.shape(),
ShapeUtil::MakeShape(PRED, {})));
// Per the explanation above, if the evaluated condition returns false, the
// loop executes at most once.

View File

@ -314,7 +314,10 @@ message DebugOptions {
// Compilation errors out if these ops are encountered.
bool xla_gpu_deterministic_ops = 148;
// Next id: 150
// Paths to files with textual LLVM IR.
repeated string xla_gpu_llvm_ir_file = 150;
// Next id: 151
// Extra options to pass to the compilation backend (e.g. LLVM); specific
// interpretation of these values is left to the backend.

View File

@ -602,7 +602,6 @@ cc_library(
"//tensorflow/core/kernels:batch_kernels",
"//tensorflow/core/kernels:bincount_op",
"//tensorflow/core/kernels:boosted_trees_ops",
"//tensorflow/core/kernels:tensor_forest_ops",
"//tensorflow/core/kernels:candidate_sampler_ops",
"//tensorflow/core/kernels:checkpoint_ops",
"//tensorflow/core/kernels:clustering_ops",
@ -994,7 +993,6 @@ filegroup(
"stateless_random_ops_v2_op_lib",
"string_ops_op_lib",
"summary_ops_op_lib",
"tensor_forest_ops_op_lib",
"tpu_configuration_ops_op_lib",
"tpu_cross_replica_ops_op_lib",
"tpu_embedding_ops_op_lib",

View File

@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestCreateTreeVariable"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be created.
END
}
in_arg {
name: "tree_config"
description: <<END
Serialized proto string of the boosted_trees.Tree.
END
}
summary: "Creates a tree resource and returns a handle to it."
}

View File

@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeDeserialize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be restored.
END
}
in_arg {
name: "tree_config"
description: <<END
Serialied proto string of the boosted_trees.Tree proto.
END
}
summary: "Deserializes a proto into the tree handle"
}

View File

@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeIsInitializedOp"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree.
END
}
out_arg {
name: "is_initialized"
description: <<END
Whether the tree is initialized.
END
}
summary: "Checks whether a tree has been initialized."
}

View File

@ -1,29 +0,0 @@
op {
graph_op_name: "TensorForestTreePredict"
visibility: HIDDEN
attr {
name: "logits_dimension"
description: <<END
Scalar, dimension of the logits.
END
}
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource.
END
}
in_arg {
name: "dense_features"
description: <<END
Rank 2 dense features tensor.
END
}
out_arg {
name: "logits"
description: <<END
The logits predictions from the tree for each instance in the batch.
END
}
summary: "Output the logits for the given input data"
}

View File

@ -1,5 +0,0 @@
op {
graph_op_name: "TensorForestTreeResourceHandleOp"
visibility: HIDDEN
summary: "Creates a handle to a TensorForestTreeResource"
}

View File

@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeSerialize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource to be serialized.
END
}
out_arg {
name: "tree_config"
description: <<END
Serialied proto string of the tree resource.
END
}
summary: "Serializes the tree handle to a proto"
}

View File

@ -1,17 +0,0 @@
op {
graph_op_name: "TensorForestTreeSize"
visibility: HIDDEN
in_arg {
name: "tree_handle"
description: <<END
Handle to the tree resource.
END
}
out_arg {
name: "tree_size"
description: <<END
The size of the tree.
END
}
summary: "Get the number of nodes in a tree"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestCreateTreeVariable"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeDeserialize"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeIsInitializedOp"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreePredict"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeResourceHandleOp"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeSerialize"
}

View File

@ -1,3 +0,0 @@
op {
graph_op_name: "TensorForestTreeSize"
}

View File

@ -191,19 +191,33 @@ Status ResourceMgr::DoCreate(const string& container, TypeIndex type,
type.name());
}
Status ResourceMgr::Lookup(const ResourceHandle& handle,
ResourceBase** resource) const {
tf_shared_lock l(mu_);
return DoLookup(handle.container(), handle.hash_code(),
/*type_name=*/"ResourceBase", handle.name(), resource);
}
Status ResourceMgr::DoLookup(const string& container, TypeIndex type,
const string& name,
ResourceBase** resource) const {
return DoLookup(container, type.hash_code(), type.name(), name, resource);
}
Status ResourceMgr::DoLookup(const string& container, uint64 type_hash_code,
const string& type_name,
const string& resource_name,
ResourceBase** resource) const {
const Container* b = gtl::FindPtrOrNull(containers_, container);
if (b == nullptr) {
return errors::NotFound("Container ", container,
" does not exist. (Could not find resource: ",
container, "/", name, ")");
container, "/", resource_name, ")");
}
auto iter = b->find({type.hash_code(), name});
auto iter = b->find({type_hash_code, resource_name});
if (iter == b->end()) {
return errors::NotFound("Resource ", container, "/", name, "/", type.name(),
" does not exist.");
return errors::NotFound("Resource ", container, "/", resource_name, "/",
type_name, " does not exist.");
}
*resource = const_cast<ResourceBase*>(iter->second.resource.get());
(*resource)->Ref();
@ -326,6 +340,12 @@ Status HandleFromInput(OpKernelContext* ctx, StringPiece input,
return Status::OK();
}
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
ResourceBase** value) {
TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
return ctx->resource_manager()->Lookup(p, value);
}
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDevice(ctx, p));
return ctx->resource_manager()->Delete(p);

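As a usage illustration of the type-erased lookup added above, a hypothetical kernel fragment (not from this change; it assumes the usual OpKernel boilerplate and that input 0 carries the resource handle) could look like:

// Hypothetical OpKernel::Compute body; LookupResource here is the new
// ResourceBase** overload, and the caller owns one ref on success.
void Compute(OpKernelContext* ctx) override {
  const ResourceHandle& handle = HandleFromInput(ctx, 0);
  ResourceBase* resource = nullptr;
  OP_REQUIRES_OK(ctx, LookupResource(ctx, handle, &resource));
  core::ScopedUnref unref(resource);  // drop our ref when Compute returns
  VLOG(1) << "Looked up resource: " << resource->DebugString();
}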
View File

@ -181,6 +181,13 @@ class ResourceMgr {
Status Lookup(const std::string& container, const std::string& name,
T** resource) const TF_MUST_USE_RESULT;
// If the resource manager has a resource matching "handle", returns it in
// "*resource" and the caller takes the ownership of one ref on "*resource".
//
// REQUIRES: resource != nullptr
Status Lookup(const ResourceHandle& handle,
ResourceBase** resource) const TF_MUST_USE_RESULT;
// Similar to Lookup, but looks up multiple resources at once, with only a
// single lock acquisition. If containers_and_names[i] is uninitialized
// then this function does not modify resources[i].
@ -260,6 +267,9 @@ class ResourceMgr {
Status LookupInternal(const std::string& container, const std::string& name,
T** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
Status LookupInternal(const std::string& container, uint64 type_hash_code,
const std::string& name, ResourceBase** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
Status DoCreate(const std::string& container, TypeIndex type,
const std::string& name, ResourceBase* resource)
@ -268,6 +278,11 @@ class ResourceMgr {
Status DoLookup(const std::string& container, TypeIndex type,
const std::string& name, ResourceBase** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
Status DoLookup(const std::string& container, uint64 type_hash_code,
const std::string& type_name,
const std::string& resource_name,
ResourceBase** resource) const
TF_SHARED_LOCKS_REQUIRED(mu_) TF_MUST_USE_RESULT;
Status DoDelete(const std::string& container, uint64 type_hash_code,
const std::string& resource_name,
@ -733,12 +748,17 @@ Status ValidateDeviceAndType(OpKernelContext* ctx, const ResourceHandle& p) {
} // namespace internal
// Creates the resource pointed at by "p". The caller transfers the ownership of
// one ref on "*value" to the resource manager in "ctx", regardless of whether
// this operation succeeds or fails.
template <typename T>
Status CreateResource(OpKernelContext* ctx, const ResourceHandle& p, T* value) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
return ctx->resource_manager()->Create(p.container(), p.name(), value);
}
// If the resource manager in "ctx" has a resource matching "p", returns it in
// "*value" and the caller takes the ownership of one ref on "*value".
template <typename T, bool use_dynamic_cast>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value) {
@ -747,6 +767,13 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
p.name(), value);
}
// If the resource manager in "ctx" has a resource matching "p", returns it in
// "*value" and the caller takes the ownership of one ref on "*value".
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
ResourceBase** value);
// If the resource manager in "ctx" has a resource matching "p", returns it in
// "*value".
template <typename T>
Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value) {
@ -757,6 +784,8 @@ Status LookupResource(OpKernelContext* ctx, const ResourceHandle& p,
return Status::OK();
}
// Similar to Lookup, but looks up multiple resources at once, with only a
// single lock acquisition.
template <typename T>
Status LookupResources(OpKernelContext* ctx,
absl::Span<ResourceHandle const* const> p,
@ -770,6 +799,13 @@ Status LookupResources(OpKernelContext* ctx,
return ctx->resource_manager()->LookupMany(containers_and_names, values);
}
// If the resource manager in "ctx" has a resource pointed at by "p", returns
// it in "*value". Otherwise, invokes creator() to create the resource.
// The caller takes the ownership of one ref on "*value".
//
// WARNING: creator() must not call any methods on the resource manager during
// its execution, because a non-reentrant lock is held during the creator() call
// in order to guarantee atomicity of LookupOrCreateResource().
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
T** value, std::function<Status(T**)> creator) {
@ -778,6 +814,12 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
creator);
}
// If the resource manager in "ctx" has a resource pointed at by "p", returns
// it in "*value". Otherwise, invokes creator() to create the resource.
//
// WARNING: creator() must not call any methods on the resource manager during
// its execution, because a non-reentrant lock is held during the creator() call
// in order to guarantee atomicity of LookupOrCreateResource().
template <typename T>
Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
core::RefCountPtr<T>* value,
@ -789,12 +831,14 @@ Status LookupOrCreateResource(OpKernelContext* ctx, const ResourceHandle& p,
return Status::OK();
}
// Deletes the resource pointed by "p", using the resource manager in "ctx".
template <typename T>
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p) {
TF_RETURN_IF_ERROR(internal::ValidateDeviceAndType<T>(ctx, p));
return ctx->resource_manager()->Delete<T>(p.container(), p.name());
}
// Deletes the resource pointed by "p", using the resource manager in "ctx".
Status DeleteResource(OpKernelContext* ctx, const ResourceHandle& p);
template <typename T>
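A minimal sketch of the LookupOrCreateResource overload declared above; "MyVar" is a hypothetical ResourceBase subclass, and "ctx" and "handle" are assumed to exist in the surrounding kernel code:
core::RefCountPtr<MyVar> var;
TF_RETURN_IF_ERROR(LookupOrCreateResource<MyVar>(
    ctx, handle, &var, [](MyVar** out) {
      // Must not call back into the resource manager here: a non-reentrant
      // lock is held for the duration of creator().
      *out = new MyVar();
      return Status::OK();
    }));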

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/refcount.h"
#include "tensorflow/core/platform/regexp.h"
#include "tensorflow/core/platform/test.h"
@ -285,6 +286,36 @@ TEST(ResourceHandleTest, CRUD) {
}
}
TEST(ResourceHandleTest, LookupDeleteGenericResource) {
ResourceMgr resource_mgr("");
OpKernelContext::Params params;
params.resource_manager = &resource_mgr;
StubDevice device("device_name");
params.device = &device;
OpKernelContext ctx(&params, 0);
ResourceHandle p =
MakeResourceHandle<StubResource>(&ctx, "container", "name");
{
auto* r = new StubResource();
r->value_ = 42;
TF_EXPECT_OK(CreateResource(&ctx, p, r));
}
{
ResourceBase* r;
TF_ASSERT_OK(LookupResource(&ctx, p, &r));
ASSERT_TRUE(r != nullptr);
core::ScopedUnref unref(r);
EXPECT_EQ(static_cast<StubResource*>(r)->value_, 42);
}
{
TF_EXPECT_OK(DeleteResource(&ctx, p));
ResourceBase* unused;
EXPECT_FALSE(LookupResource(&ctx, p, &unused).ok());
}
}
TEST(ResourceHandleTest, DifferentDevice) {
ResourceMgr resource_mgr("");
OpKernelContext::Params params;

View File

@ -7240,13 +7240,6 @@ tf_kernel_library(
],
)
tf_kernel_library(
name = "tensor_forest_ops",
deps = [
"//tensorflow/core/kernels/tensor_forest:tensor_forest_ops",
],
)
tf_kernel_library(
name = "data_service_ops",
deps = [

View File

@ -1,141 +0,0 @@
# Description:
# quantization-specific OpKernels for hexagon
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_test",
)
load("//tensorflow:tensorflow.bzl", "tf_kernel_library")
package(
default_visibility = ["//visibility:public"],
licenses = ["notice"], # Apache 2.0
)
tf_cc_test(
name = "graph_transferer_test",
size = "small",
srcs = [
"graph_transferer_test.cc",
"hexagon_graph_execution_test.cc",
],
data = ["//tensorflow/core/example:example_parser_configuration_testdata"],
deps = [
":graph_transferer",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:array_ops_op_lib",
"//tensorflow/core:bitwise_ops_op_lib",
"//tensorflow/core:core_cpu",
"//tensorflow/core:direct_session",
"//tensorflow/core:functional_ops_op_lib",
"//tensorflow/core:lib",
"//tensorflow/core:math_ops_op_lib",
"//tensorflow/core:mkl_nn_ops_op_lib",
"//tensorflow/core:nn_ops_op_lib",
"//tensorflow/core:ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:remote_fused_graph_ops_op_lib",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:cwise_op",
"//tensorflow/core/kernels:quantization_utils",
"//tensorflow/core/kernels:quantized_ops",
"//tensorflow/core/kernels:reduction_ops",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/core/kernels:remote_fused_graph_ops",
"//tensorflow/core/kernels:reshape_op",
"//tensorflow/core/kernels:softmax_op",
"@com_google_absl//absl/base",
],
)
tf_kernel_library(
name = "graph_transferer",
srcs = [
"graph_transfer_utils.cc",
"graph_transferer.cc",
"hexagon_control_wrapper.cc",
"hexagon_ops_definitions.cc",
"soc_interface.cc",
],
hdrs = [
"graph_transfer_utils.h",
"graph_transferer.h",
"hexagon_control_wrapper.h",
"hexagon_ops_definitions.h",
"soc_interface.h",
],
deps = [
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:remote_fused_graph_ops",
"//tensorflow/cc:scope",
"//tensorflow/core",
"//tensorflow/core:core_cpu",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//third_party/eigen3",
],
)
cc_library(
name = "hexagon_rewriter_transform",
srcs = [
"hexagon_rewriter_transform.cc",
],
deps = [
":graph_transferer",
"//tensorflow/cc:cc_ops",
"//tensorflow/cc:remote_fused_graph_ops",
"//tensorflow/tools/graph_transforms:transform_utils",
],
alwayslink = 1,
)
tf_cc_test(
name = "hexagon_rewriter_transform_test",
size = "small",
srcs = ["hexagon_rewriter_transform_test.cc"],
deps = [
":graph_transferer",
":hexagon_rewriter_transform",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:core_cpu",
"//tensorflow/core:core_cpu_internal",
"//tensorflow/core:framework",
"//tensorflow/core:tensorflow",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
"//tensorflow/tools/graph_transforms:transform_utils",
],
)
cc_library(
name = "hexagon_remote_fused_graph_executor_build",
srcs = [
"hexagon_remote_fused_graph_executor_build.cc",
],
deps = [
":graph_transferer",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
],
alwayslink = 1,
)
tf_cc_test(
name = "hexagon_remote_fused_graph_executor_build_test",
size = "small",
srcs = ["hexagon_remote_fused_graph_executor_build_test.cc"],
deps = [
":hexagon_remote_fused_graph_executor_build",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core:testlib",
"//tensorflow/core/kernels:remote_fused_graph_execute_utils",
],
)

View File

@ -1,167 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/cc/framework/scope.h"
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
// function alias
constexpr auto AddOutputTensorShapeTypeByTensorShapeMap =
&RemoteFusedGraphExecuteUtils::AddOutputTensorShapeTypeByTensorShapeMap;
/* static */ std::priority_queue<std::tuple<float, int, string>>
GraphTransferUtils::GetTopNFloatResults(const float* const data,
const string* const labels,
const int element_count) {
CHECK(data != nullptr);
CHECK(labels != nullptr);
std::priority_queue<std::tuple<float, int, string>> queue;
for (int i = 0; i < element_count; ++i) {
queue.emplace(data[i], i, labels[i]);
}
return queue;
}
/* static */ void GraphTransferUtils::DumpTopNFloatResults(
const float* const data, const string* const labels,
const int element_count, const int top_n) {
std::priority_queue<std::tuple<float, int, string>> queue =
GetTopNFloatResults(data, labels, element_count);
LOG(INFO) << "=== Dump ranking ===";
for (int i = 0; i < top_n; ++i) {
const std::tuple<float, int, string>& entry = queue.top();
LOG(INFO) << i << ": " << std::get<1>(entry) << ", " << std::get<2>(entry)
<< ", " << std::get<0>(entry);
queue.pop();
}
}
/* static */ RemoteFusedGraphExecuteInfo
GraphTransferUtils::BuildRemoteFusedGraphExecuteInfo(
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs,
const RemoteFusedGraphExecuteUtils::TensorShapeMap& tensor_shape_map) {
RemoteFusedGraphExecuteInfo execute_info;
execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");
// copy graph
*execute_info.mutable_remote_graph() = graph_def;
for (const std::pair<string, Tensor>& input : inputs) {
execute_info.add_graph_input_node_name(input.first);
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
*execute_info.add_default_graph_input_tensor_shape();
tensor_shape_type.set_dtype(input.second.dtype());
TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
for (const int64 dim : input.second.shape().dim_sizes()) {
tensor_shape_proto.add_dim()->set_size(dim);
}
}
for (const string& output_name : outputs) {
const std::pair<DataType, TensorShape>* tensor_shape_type =
RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
output_name);
CHECK_NOTNULL(tensor_shape_type);
execute_info.add_graph_output_node_name(output_name);
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type_proto =
*execute_info.add_default_graph_output_tensor_shape();
tensor_shape_type_proto.set_dtype(tensor_shape_type->first);
TensorShapeProto& tensor_shape_proto =
*tensor_shape_type_proto.mutable_shape();
for (const int64 dim : tensor_shape_type->second.dim_sizes()) {
tensor_shape_proto.add_dim()->set_size(dim);
}
}
return execute_info;
}
/* static */ GraphDef GraphTransferUtils::BuildFusedGraphDef(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs, GraphDef* original_def) {
RemoteFusedGraphExecuteUtils::TensorShapeMap tensor_shape_map;
Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
*original_def, inputs, true /* initialize_by_zero */, &tensor_shape_map);
for (NodeDef& node_def : *original_def->mutable_node()) {
TF_CHECK_OK(
AddOutputTensorShapeTypeByTensorShapeMap(tensor_shape_map, &node_def));
}
CHECK(status.ok());
Scope root = Scope::NewRootScope();
std::vector<Output> output_list;
DataTypeVector input_types;
for (const std::pair<string, Tensor>& input_node_info : inputs) {
const Scope& scope = root.WithOpName(input_node_info.first);
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("Placeholder");
auto builder = NodeBuilder(unique_name, "Placeholder")
.Attr("dtype", input_node_info.second.dtype())
.Attr("shape", input_node_info.second.shape());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
TF_CHECK_OK(scope.status());
output_list.emplace_back(Output(ret, 0));
input_types.push_back(input_node_info.second.dtype());
}
const RemoteFusedGraphExecuteInfo execute_info =
BuildRemoteFusedGraphExecuteInfo(*original_def, inputs, outputs,
tensor_shape_map);
DataTypeVector output_types;
// Sanity-check to confirm all output data types are the same.
for (const string& output_node_name : outputs) {
const std::pair<DataType, TensorShape>* tst =
RemoteFusedGraphExecuteUtils::GetTensorShapeType(tensor_shape_map,
output_node_name);
CHECK_NE(tst, nullptr);
output_types.push_back(tst->first);
}
const Scope& scope = root.WithOpName(remote_graph_execute_name);
CHECK(scope.ok());
auto node_out_list = ops::AsNodeOutList(scope, InputList(output_list));
Node* node;
const auto unique_name = scope.GetUniqueNameForOp("RemoteFusedGraphExecute");
auto builder = NodeBuilder(unique_name, "RemoteFusedGraphExecute")
.Input(node_out_list)
.Attr("Tinputs", input_types)
.Attr("Toutputs", output_types)
.Attr("serialized_remote_fused_graph_execute_info",
StringPiece(execute_info.SerializeAsString()));
CHECK(scope.ok());
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &node));
CHECK(scope.ok()) << scope.status();
GraphDef fusedGraphDef;
TF_CHECK_OK(root.ToGraphDef(&fusedGraphDef));
return fusedGraphDef;
}
} // namespace tensorflow
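A short usage sketch of the helpers above, assuming "probs" points at "count" float scores and "labels" at "count" matching label strings:
// Rank the inference outputs and dump the best five.
auto ranking = GraphTransferUtils::GetTopNFloatResults(probs, labels, count);
LOG(INFO) << "Top label: " << std::get<2>(ranking.top())
          << ", score: " << std::get<0>(ranking.top());
GraphTransferUtils::DumpTopNFloatResults(probs, labels, count, /*top_n=*/5);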

View File

@ -1,59 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_
#include <queue>
#include <utility>
#include <vector>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
class RemoteFusedGraphExecuteInfo;
class GraphTransferUtils {
public:
static std::priority_queue<std::tuple<float, int, string>>
GetTopNFloatResults(const float* const data, const string* const labels,
const int element_count);
static void DumpTopNFloatResults(const float* const data,
const string* const labels,
const int element_count, const int top_n);
static GraphDef BuildFusedGraphDef(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& remote_graph_execute_name,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs, GraphDef* original_def);
private:
static RemoteFusedGraphExecuteInfo BuildRemoteFusedGraphExecuteInfo(
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& inputs,
const std::vector<string>& outputs,
const RemoteFusedGraphExecuteUtils::TensorShapeMap& tensor_shape_map);
TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferUtils);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFER_UTILS_H_

File diff suppressed because it is too large

View File

@ -1,231 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
#include <array>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/util/padding.h"
namespace tensorflow {
class GraphTransferInfo;
class GraphTransferNodeInfo;
class GraphTransferNodeInputInfo;
// GraphTransferer transfers graph definitions into SoC memory.
// This functionality is effective if the SoC is capable of running
// the graph on that chip.
// TODO(satok): support transferring subgraphs to be able to split graphs
// to avoid unsupported ops in SoC.
class GraphTransferer {
public:
// TODO(satok): Remove. Use proto definition instead.
static constexpr int MAX_SUPPORTED_RANK = 4;
// TODO(satok): Remove. Use proto definition instead.
static constexpr int SHAPE_ARRAY_SIZE = MAX_SUPPORTED_RANK;
using TensorShapeMap = RemoteFusedGraphExecuteUtils::TensorShapeMap;
GraphTransferer();
~GraphTransferer();
// Load graph structure into GraphTransferer
// TODO(satok): Pass a pair of TensorShape and DataType instead of
// Tensor as input_node_info_list.
Status LoadGraphFromProto(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const GraphDef& graph_def,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names,
const bool shape_inference_for_unknown_shape);
// Load graph structure into GraphTransferer from protobuf file
// TODO(satok): Pass a pair of TensorShape and DataType instead of
// Tensor as input_node_info_list.
Status LoadGraphFromProtoFile(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const string& graph_def_path,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names, const bool is_text_proto,
const bool shape_inference_for_unknown_shape,
const bool dry_run_for_unknown_shape);
// Sort params so that all input nodes appear before consumer nodes.
// CAVEAT: This may be slow if the number of nodes is too large
void SortParams(const std::vector<string>& output_node_names);
void EnableStrictCheckMode(bool enable);
// Import parameters for transfer
void SetSerializedGraphTransferInfo(const string& serialized_proto);
// Return parameters for graph transfer
const GraphTransferInfo& GetGraphTransferInfo() const;
// Return mutable GraphTransferInfo for graph transfer
GraphTransferInfo& GetMutableGraphTransferInfo();
// Dump verification string of parameters to verify with offline tools
void DumpVerificationStringOfNodeTransferParams() const;
static std::array<int64, SHAPE_ARRAY_SIZE> ToTensorShapeArray(
const TensorShape& shape);
private:
class TransferParamsComparator {
public:
TransferParamsComparator(
const std::unordered_map<int, std::unordered_set<int>>& dep_map);
bool operator()(const GraphTransferNodeInfo& obj0,
const GraphTransferNodeInfo& obj1);
const std::unordered_map<int, std::unordered_set<int>>& dependency_map_;
};
void CacheNode(const Node& node);
bool AreAllInputsCached(const Node& node) const;
// Transform a remote fused graph to add an aggregated input node which takes
// all inputs of the remote graph.
Status TransformGraphToAddAggregatedInputNode(
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
Graph* graph, ShapeRefiner* shape_refiner);
Status RegisterNode(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names);
void RegisterConstantNode(const ShapeRefiner& shape_refiner,
const Node& node);
int RegisterConstantShape(const std::vector<int>& shape);
int RegisterConstTensor(const Tensor& tensor, const string& suffix);
int RegisterConstScalar(const DataType dt, const int val, const int dst_id,
const int dst_input_count);
bool HasPaddingAndStrides(const Node& node);
bool NeedsToAddRank(const Node& node);
bool IsPadNode(const Node& node);
// Return true if the node is a reshape op which just flattens input
// TODO(satok): Remove this method once generic reshape op is implemented in
// SOC
bool IsNodeFlattenReshape(const Node& node,
const ShapeRefiner& shape_refiner);
void RegisterNodeWithPaddingAndStrides(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterNodeWithRank(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterPadNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterInputNode(const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterFlattenNode(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
void RegisterGenericNode(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node);
Status RegisterNodeIfAllInputsAreCached(
const IRemoteFusedGraphOpsDefinitions& ops_definitions,
const ShapeRefiner& shape_refiner, const Node& node,
const bool only_register_const_node,
const std::vector<std::pair<string, Tensor>>& input_node_info_list,
const std::vector<string>& output_node_names);
void AppendNodeParams(const string& name, const int id, const string& type,
const int type_id, const int padding,
const int inputs_size,
const std::vector<int>& extra_inputs,
const int outputs_size);
void AddNodeInputByInputIndex(const Node& node, const int idx,
GraphTransferNodeInputInfo* node_input_info);
void AppendNodeInputParams(const int id, const Node& node,
const std::vector<int>& extra_inputs);
void AppendNodeOutputParams(const ShapeRefiner& shape_refiner, const int id,
const Node& node);
static std::array<int64, SHAPE_ARRAY_SIZE> BuildShapeArray(
const shape_inference::ShapeHandle& shape_handle,
shape_inference::InferenceContext* context);
void AppendNodeParamsWithIoParams(
const ShapeRefiner& shape_refiner, const Node& node, const string& name,
const int id, const string& type, const int type_id, const int padding,
const int inputs_size, const std::vector<int>& extra_inputs,
const int outputs_size, const bool append_input_params,
const bool append_output_params);
static string ToPaddingDebugString(int padding);
// Create dependency map
static void FillDependencyRec(
int node_id, std::unordered_map<int, std::unordered_set<int>>& dep_map,
std::unordered_set<int>& completed);
// Build tensor from proto
static Status MakeTensorFromProto(const TensorProto& tensor_proto,
Tensor* tensor);
void ClearCache();
// Dump pretty print of parameters
void DumpNodeTransferParams() const;
GraphTransferInfo* graph_transfer_info_;
std::vector<const Node*> node_name_cache_list_{};
std::unordered_map<string, int> node_name_to_id_cache_map_{};
// strict check mode is true by default. Disable this if the ops' shape
// inferences are not implemented correctly.
bool strict_check_mode_{true};
TF_DISALLOW_COPY_AND_ASSIGN(GraphTransferer);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_GRAPH_TRANSFERER_H_
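A condensed usage sketch of the class above; "ops_definitions", "graph_def", "inputs", and "outputs" are assumed inputs (the test file that follows exercises the same API in full):
GraphTransferer gt;
gt.EnableStrictCheckMode(false);  // relax checks for ops without shape fns
Status s = gt.LoadGraphFromProto(ops_definitions, graph_def, inputs, outputs,
                                 /*shape_inference_for_unknown_shape=*/true);
if (s.ok()) {
  const GraphTransferInfo& info = gt.GetGraphTransferInfo();
  LOG(INFO) << "Registered op nodes: " << info.node_info_size();
}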

View File

@ -1,481 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <memory>
#include "tensorflow/cc/ops/const_op.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/graph/graph_def_builder.h"
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
const string NAME_A = "a";
const string NAME_B = "b";
const string NAME_A_PLUS_B = "a_plus_b";
constexpr float NODE_A_VAL = 2.0f;
constexpr float NODE_B_VAL = 3.0f;
constexpr float VALUE_TOLERANCE_FLOAT = 1e-8f;
class GraphTransfererTest : public ::testing::Test {
protected:
void SetUp() final {}
GraphTransferer gt_;
};
const RemoteFusedGraphExecuteUtils::TensorShapeMap EMPTY_OUTPUT_TENSOR_MAP;
class TestGraphTransferOpsDefinitions : public IRemoteFusedGraphOpsDefinitions {
public:
int GetTotalOpsCount() const final { return op_types_.size(); }
int GetOpIdFor(const string& op_type, const DataTypeVector&) const final {
for (int i = 0; i < op_types_.size(); ++i) {
if (op_types_[i] == op_type) {
return i;
}
}
return -1;
}
private:
const std::vector<string> op_types_{"INPUT", "OUTPUT", "Conv2D",
"MaxPool", "NoOp", "Add",
"Const", "Softmax", "Identity"};
} TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
static Output BuildAddOps(const Scope& scope, const Input& x, const Input& y) {
EXPECT_TRUE(scope.ok());
auto _x = ops::AsNodeOut(scope, x);
EXPECT_TRUE(scope.ok());
auto _y = ops::AsNodeOut(scope, y);
EXPECT_TRUE(scope.ok());
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("Add");
auto builder = NodeBuilder(unique_name, "Add").Input(_x).Input(_y);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
EXPECT_TRUE(scope.ok());
return Output(ret, 0);
}
static Output BuildSoftmaxOps(const Scope& scope, const Input& logits) {
EXPECT_TRUE(scope.ok());
auto _logits = ops::AsNodeOut(scope, logits);
EXPECT_TRUE(scope.ok());
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("Softmax");
auto builder = NodeBuilder(unique_name, "Softmax").Input(_logits);
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
EXPECT_TRUE(scope.ok());
return Output(ret, 0);
}
static Output BuildConv2DOps(const Scope& scope, const Input& input,
const Input& filter,
const gtl::ArraySlice<int>& strides,
const StringPiece& padding) {
EXPECT_TRUE(scope.ok());
auto _input = ops::AsNodeOut(scope, input);
EXPECT_TRUE(scope.ok());
auto _filter = ops::AsNodeOut(scope, filter);
EXPECT_TRUE(scope.ok());
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("Conv2D");
auto builder = NodeBuilder(unique_name, "Conv2D")
.Input(_input)
.Input(_filter)
.Attr("strides", strides)
.Attr("use_cudnn_on_gpu", true)
.Attr("padding", padding)
.Attr("data_format", "NHWC");
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
EXPECT_TRUE(scope.ok());
return Output(ret, 0);
}
static Output BuildMaxPoolOps(const Scope& scope, const Input& input,
const gtl::ArraySlice<int>& ksize,
const gtl::ArraySlice<int>& strides,
const StringPiece& padding) {
EXPECT_TRUE(scope.ok());
auto _input = ops::AsNodeOut(scope, input);
EXPECT_TRUE(scope.ok());
Node* ret;
const auto unique_name = scope.GetUniqueNameForOp("MaxPool");
auto builder = NodeBuilder(unique_name, "MaxPool")
.Input(_input)
.Attr("ksize", ksize)
.Attr("strides", strides)
.Attr("padding", padding)
.Attr("data_format", "NHWC");
scope.UpdateBuilder(&builder);
scope.UpdateStatus(builder.Finalize(scope.graph(), &ret));
EXPECT_TRUE(scope.ok());
return Output(ret, 0);
}
static GraphDef CreateAddGraphDef() {
Scope root = Scope::NewRootScope();
Output node_a = ops::Const(root.WithOpName(NAME_A), NODE_A_VAL);
Output node_b = ops::Const(root.WithOpName(NAME_B), NODE_B_VAL);
Output node_add = BuildAddOps(root.WithOpName(NAME_A_PLUS_B), node_a, node_b);
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
static GraphDef CreateConvGraphDef() {
Scope root = Scope::NewRootScope();
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
test::FillIota<float>(&input_data, 1.0f);
Output input =
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
test::FillIota<float>(&filter_data, 1.0f);
Output filter =
ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data));
const std::vector<int> strides{1, 1, 1, 1};
Output conv =
BuildConv2DOps(root.WithOpName("conv"), input, filter, strides, "SAME");
Output softmax = BuildSoftmaxOps(root.WithOpName("softmax"), conv);
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
static GraphDef CreatePoolGraphDef() {
Scope root = Scope::NewRootScope();
Tensor input_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
test::FillIota<float>(&input_data, 1.0f);
Output input =
ops::Const(root.WithOpName("input"), Input::Initializer(input_data));
Tensor filter_data(DT_FLOAT, TensorShape({1, 1, 1, 1}));
test::FillIota<float>(&filter_data, 1.0f);
Output filter =
ops::Const(root.WithOpName("filter"), Input::Initializer(filter_data));
const std::vector<int> ksize{1, 1, 1, 1};
const std::vector<int> padding{0, 0, 0, 0};
const std::vector<int> strides{1, 1, 1, 1};
Output max_pool = BuildMaxPoolOps(root.WithOpName("maxpool"), input, ksize,
strides, "SAME");
Output softmax = BuildSoftmaxOps(root.WithOpName("softmax"), max_pool);
GraphDef def;
TF_CHECK_OK(root.ToGraphDef(&def));
return def;
}
static const GraphTransferConstNodeInfo* FindConstNodeInfo(
const GraphTransferer& gt, const string& name) {
for (const GraphTransferConstNodeInfo& params :
gt.GetGraphTransferInfo().const_node_info()) {
if (params.name() == name) {
return &params;
}
}
return nullptr;
}
static const GraphTransferNodeInfo* FindNodeInfo(const GraphTransferer& gt,
const string& name) {
for (const GraphTransferNodeInfo& params :
gt.GetGraphTransferInfo().node_info()) {
if (params.name() == name) {
return &params;
}
}
return nullptr;
}
static const GraphTransferNodeInputInfo* FindNodeInputInfo(
const GraphTransferer& gt, const int node_id) {
for (const GraphTransferNodeInputInfo& params :
gt.GetGraphTransferInfo().node_input_info()) {
if (params.node_id() == node_id) {
return &params;
}
}
return nullptr;
}
static const GraphTransferNodeOutputInfo* FindNodeOutputInfo(
const GraphTransferer& gt, const int node_id) {
for (const GraphTransferNodeOutputInfo& params :
gt.GetGraphTransferInfo().node_output_info()) {
if (params.node_id() == node_id) {
return &params;
}
}
return nullptr;
}
static void SanityCheckNodes(const GraphTransferer& gt) {
for (const GraphTransferNodeInfo& params :
gt.GetGraphTransferInfo().node_info()) {
if (params.input_count() > 0) {
const GraphTransferNodeInputInfo* input_params =
FindNodeInputInfo(gt, params.node_id());
ASSERT_NE(nullptr, input_params);
EXPECT_EQ(params.input_count(), input_params->node_input_size());
EXPECT_EQ(params.node_id(), input_params->node_id());
for (const GraphTransferNodeInput& node_input :
input_params->node_input()) {
EXPECT_GE(node_input.output_port(), 0);
}
}
if (params.output_count() > 0) {
const GraphTransferNodeOutputInfo* output_params =
FindNodeOutputInfo(gt, params.node_id());
ASSERT_NE(nullptr, output_params);
EXPECT_EQ(params.output_count(), output_params->max_byte_size_size());
EXPECT_EQ(params.node_id(), output_params->node_id());
for (const int max_size : output_params->max_byte_size()) {
EXPECT_GE(max_size, 0);
}
}
}
}
TEST_F(GraphTransfererTest, LoadAddGraph) {
GraphDef def = CreateAddGraphDef();
ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
{}, std::vector<string>{NAME_A_PLUS_B},
false)
.ok());
SanityCheckNodes(gt_);
const int const_node_count =
gt_.GetGraphTransferInfo().const_node_info_size();
ASSERT_EQ(2, const_node_count);
const GraphTransferConstNodeInfo* params_a = FindConstNodeInfo(gt_, NAME_A);
ASSERT_TRUE(params_a != nullptr);
EXPECT_EQ(NAME_A, params_a->name());
ASSERT_EQ(4, params_a->shape_size());
EXPECT_EQ(1, params_a->shape(0));
EXPECT_EQ(1, params_a->shape(1));
EXPECT_EQ(1, params_a->shape(2));
EXPECT_EQ(1, params_a->shape(3));
EXPECT_EQ(4, params_a->data().length());
const GraphTransferConstNodeInfo* params_b = FindConstNodeInfo(gt_, NAME_B);
ASSERT_TRUE(params_b != nullptr);
ASSERT_EQ(4, params_b->shape_size());
EXPECT_EQ(1, params_b->shape(0));
EXPECT_EQ(1, params_b->shape(1));
EXPECT_EQ(1, params_b->shape(2));
EXPECT_EQ(1, params_b->shape(3));
EXPECT_EQ(4, params_b->data().length());
}
TEST_F(GraphTransfererTest, LoadAddGraphWithOutputTensorMap) {
GraphDef def = CreateAddGraphDef();
std::pair<string, Tensor> input_node_info_a;
input_node_info_a.first = NAME_A;
input_node_info_a.second = Tensor(DT_FLOAT, {});
input_node_info_a.second.scalar<float>()() = 1.0f;
const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a};
RemoteFusedGraphExecuteUtils::TensorShapeMap output_tensor_info;
Status status = RemoteFusedGraphExecuteUtils::DryRunInferenceForAllNode(
def, inputs, {}, &output_tensor_info);
ASSERT_TRUE(status.ok()) << status;
const std::vector<string> output_node_names = {NAME_A_PLUS_B};
status = gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
inputs, output_node_names, false);
TF_ASSERT_OK(status);
}
TEST_F(GraphTransfererTest, LoadConvGraph) {
GraphDef def = CreateConvGraphDef();
std::vector<std::pair<string, Tensor>> input_node_info_list;
input_node_info_list.emplace_back(
std::pair<string, Tensor>{"input", Tensor{DT_FLOAT, {1, 1, 1, 1}}});
const std::vector<string> output_node_names = {"softmax"};
ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
input_node_info_list, output_node_names,
false)
.ok());
SanityCheckNodes(gt_);
const int const_node_count =
gt_.GetGraphTransferInfo().const_node_info_size();
ASSERT_EQ(2, const_node_count);
const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(4, op_node_count);
const GraphTransferNodeInfo* params_conv = FindNodeInfo(gt_, "conv");
ASSERT_TRUE(params_conv != nullptr);
const int id = params_conv->node_id();
EXPECT_GE(id, 0);
EXPECT_EQ("Conv2D", params_conv->type_name());
EXPECT_EQ(3, params_conv->input_count());
EXPECT_EQ(1, params_conv->output_count());
EXPECT_EQ(Padding::SAME, params_conv->padding_id());
}
TEST_F(GraphTransfererTest, LoadMaxPoolGraph) {
GraphDef def = CreatePoolGraphDef();
std::vector<std::pair<string, Tensor>> input_node_info_list;
input_node_info_list.emplace_back(
std::pair<string, Tensor>{"input", Tensor{DT_FLOAT, {1, 1, 1, 1}}});
const std::vector<string> output_node_names = {"softmax"};
ASSERT_TRUE(gt_.LoadGraphFromProto(TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, def,
input_node_info_list, output_node_names,
false)
.ok());
SanityCheckNodes(gt_);
const int const_node_count =
gt_.GetGraphTransferInfo().const_node_info_size();
ASSERT_EQ(2, const_node_count);
const int op_node_count = gt_.GetGraphTransferInfo().node_info_size();
ASSERT_EQ(4, op_node_count);
const GraphTransferNodeInfo* params_max_pool = FindNodeInfo(gt_, "maxpool");
ASSERT_TRUE(params_max_pool != nullptr);
const int id = params_max_pool->node_id();
EXPECT_GE(id, 0);
EXPECT_EQ("MaxPool", params_max_pool->type_name());
EXPECT_EQ(3, params_max_pool->input_count());
EXPECT_EQ(1, params_max_pool->output_count());
EXPECT_EQ(Padding::SAME, params_max_pool->padding_id());
}
TEST(HexagonOpsDefinitions, CheckOpsDefinitions) {
const IRemoteFusedGraphOpsDefinitions& ops_definitions =
HexagonOpsDefinitions::getInstance();
const int total_ops_count = ops_definitions.GetTotalOpsCount();
EXPECT_GT(total_ops_count, 0);
}
TEST(GraphTransferer, LoadGraphFromProtoFile) {
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
string filename =
io::JoinPath(testing::TensorFlowSrcRoot(),
"core/example/testdata/parse_example_graph_def.pbtxt");
std::vector<std::pair<string, Tensor>> input_node_info_list = {};
std::vector<string> output_node_names = {};
bool is_text_proto = true;
// Keep the following comments for debugging purposes for now
// filename = "v3_stripped_quantized_graph_opt.pb";
// input_node_info_list.emplace_back(
// std::pair<string, Tensor>{"Mul", Tensor{DT_FLOAT, {1,299,299,3}}});
// output_node_names.emplace_back("softmax");
// is_text_proto = false;
// ops_definitions = &HexagonOpsDefinitions::getInstance();
GraphTransferer gt;
gt.EnableStrictCheckMode(false);
Status status = gt.LoadGraphFromProtoFile(
*ops_definitions, filename, input_node_info_list, output_node_names,
is_text_proto, false, true);
}
TEST_F(GraphTransfererTest, BuildRemoteFusedGraphDefAddGraph) {
GraphDef def = CreateAddGraphDef();
std::pair<string, Tensor> input_node_info_a;
input_node_info_a.first = NAME_A;
input_node_info_a.second = Tensor(DT_FLOAT, {});
input_node_info_a.second.scalar<float>()() = 1.0f;
std::pair<string, Tensor> input_node_info_b;
input_node_info_b.first = NAME_B;
input_node_info_b.second = Tensor(DT_FLOAT, {});
input_node_info_b.second.scalar<float>()() = 10.0f;
const std::vector<std::pair<string, Tensor>> inputs{input_node_info_a,
input_node_info_b};
std::vector<string> outputs = {NAME_A_PLUS_B};
GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef(
TEST_GRAPH_TRANSFER_OPS_DEFINITIONS, "remote_fused_graph_execute_node",
inputs, outputs, &def);
EXPECT_EQ(3, fused_graph_def.node_size());
}
namespace {
// Just compares the max_byte_size attributes present.
void CompareGraphTransferInfo(const GraphTransferInfo& a,
const GraphTransferInfo& b) {
EXPECT_EQ(a.node_output_info_size(), b.node_output_info_size());
for (int i = 0; i < a.node_output_info_size(); ++i) {
EXPECT_EQ(a.node_output_info(i).node_id(), b.node_output_info(i).node_id());
EXPECT_EQ(a.node_output_info(i).max_byte_size_size(),
b.node_output_info(i).max_byte_size_size());
for (int j = 0; j < a.node_output_info(i).max_byte_size_size(); ++j) {
EXPECT_EQ(a.node_output_info(i).max_byte_size(j),
b.node_output_info(i).max_byte_size(j));
}
}
}
} // anonymous namespace
TEST(GraphTransferer, LoadGraphFromProtoFileShapeInferenceSimple) {
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&TEST_GRAPH_TRANSFER_OPS_DEFINITIONS;
string filename =
io::JoinPath(testing::TensorFlowSrcRoot(),
"core/example/testdata/parse_example_graph_def.pbtxt");
std::vector<std::pair<string, Tensor>> input_node_info_list = {};
std::vector<string> output_node_names = {};
bool is_text_proto = true;
// In order to run with a more complex graph uncomment the following lines
// filename = "v3_stripped_quantized_graph_opt.pb";
// input_node_info_list.emplace_back(
// std::pair<string, Tensor>{"Mul", Tensor{DT_FLOAT, {1,299,299,3}}});
// output_node_names.emplace_back("softmax");
// is_text_proto = false;
// ops_definitions = &HexagonOpsDefinitions::getInstance();
// First compute using Shape inference.
GraphTransferer si_gt;
si_gt.EnableStrictCheckMode(false);
bool shape_inference_for_unknown_shape = true;
bool dry_run_for_unknown_shape = false;
Status status1 = si_gt.LoadGraphFromProtoFile(
*ops_definitions, filename, input_node_info_list, output_node_names,
is_text_proto, shape_inference_for_unknown_shape,
dry_run_for_unknown_shape);
const GraphTransferInfo& si_graph_transfer_info =
si_gt.GetGraphTransferInfo();
// Now compute using dry run.
GraphTransferer dr_gt;
dr_gt.EnableStrictCheckMode(false);
shape_inference_for_unknown_shape = false;
dry_run_for_unknown_shape = true;
Status status2 = dr_gt.LoadGraphFromProtoFile(
*ops_definitions, filename, input_node_info_list, output_node_names,
is_text_proto, shape_inference_for_unknown_shape,
dry_run_for_unknown_shape);
const GraphTransferInfo& dr_graph_transfer_info =
dr_gt.GetGraphTransferInfo();
// Now compare both of them.
CompareGraphTransferInfo(si_graph_transfer_info, dr_graph_transfer_info);
}
} // namespace tensorflow

View File

@ -1,437 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/hexagon/soc_interface.h"
#include "tensorflow/core/platform/profile_utils/cpu_utils.h"
namespace tensorflow {
constexpr const char* const OUTPUT_OP_NAME = "OUTPUT";
constexpr const char* const REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX =
"hexagon_remote_fused_graph";
/* static */ constexpr const char* const
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME;
constexpr int ALIGNMENT_BYTES = 16;
constexpr int MAX_IN_OUT_COUNT = 128;
const bool DBG_DUMP_VERIFICATION_STRING = false;
const int DBG_LEVEL = 0; // -2: verbose, -1: debug, 0: info
const bool DBG_USE_DUMMY_INPUT = false;
const bool DBG_USE_SAMPLE_INPUT = false;
const int64 FLAG_ENABLE_PANDA_BINARY_INPUT = 0x01;
const bool DBG_DUMP_INPUT_TENSOR_AS_FLOAT_DATA = false;
static string AddPort(const string& node_name) {
if (node_name.find(':') != string::npos) {
return node_name;
} else {
return strings::StrCat(node_name, ":", 0);
}
}
static uint8* FindAlignedPointer(uint8* ptr) {
const uintptr_t data_ptr_int = reinterpret_cast<uintptr_t>(ptr);
const int shift_count =
(ALIGNMENT_BYTES - data_ptr_int % ALIGNMENT_BYTES) % ALIGNMENT_BYTES;
uint8* data_ptr = ptr + shift_count;
return data_ptr;
}
/* static */ GraphTransferNodeInfo* HexagonControlWrapper::FindNodeInfo(
const string& name, GraphTransferInfo* graph_transfer_info) {
for (GraphTransferNodeInfo& node_info :
*graph_transfer_info->mutable_node_info()) {
if (node_info.name() == name) {
return &node_info;
}
}
return nullptr;
}
int HexagonControlWrapper::GetVersion() {
return soc_interface_GetSocControllerVersion();
}
bool HexagonControlWrapper::Init(const RemoteFusedGraphExecuteInfo& info) {
soc_interface_SetLogLevel(DBG_LEVEL);
if (DBG_USE_SAMPLE_INPUT) {
soc_interface_SetDebugFlag(FLAG_ENABLE_PANDA_BINARY_INPUT);
}
if (info.serialized_executor_parameters().empty()) {
std::vector<std::pair<string, Tensor>> inputs;
std::vector<string> outputs;
RemoteFusedGraphExecuteUtils::BuildRemoteGraphInputsAndOutputsFromProto(
info, &inputs, &outputs);
Status status = graph_transferer_.LoadGraphFromProto(
HexagonOpsDefinitions::getInstance(), info.remote_graph(), inputs,
outputs,
false // shape_inference_for_unknown_shape
);
TF_CHECK_OK(status) << status;
} else {
// If graph transfer info is attached, just import it.
graph_transferer_.SetSerializedGraphTransferInfo(
info.serialized_executor_parameters());
}
execute_info_ = &info;
bool success = soc_interface_Init();
if (!success) {
LOG(ERROR) << "Hexagon initialization failed. See log output.";
return false;
}
std::vector<int> input_sizes;
std::vector<int> output_sizes;
CHECK_NOTNULL(execute_info_);
for (int i = 0; i < execute_info_->graph_input_node_name_size(); ++i) {
const string& input = execute_info_->graph_input_node_name(i);
LOG(INFO) << "Add input: " << input << ", " << i;
CHECK(input_port_map_.emplace(AddPort(input), i).second);
const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type =
execute_info_->default_graph_input_tensor_shape(i);
int64 buf_size = DataTypeSize(shape_type.dtype());
for (const TensorShapeProto::Dim& dim : shape_type.shape().dim()) {
buf_size *= dim.size();
}
input_sizes.emplace_back(static_cast<int>(buf_size));
}
for (int i = 0; i < execute_info_->graph_output_node_name_size(); ++i) {
const string& output = execute_info_->graph_output_node_name(i);
CHECK(output_port_map_.emplace(AddPort(output), i).second);
const RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& shape_type =
execute_info_->default_graph_output_tensor_shape(i);
int64 buf_size = DataTypeSize(shape_type.dtype());
for (const TensorShapeProto::Dim& dim : shape_type.shape().dim()) {
buf_size *= dim.size();
}
output_sizes.emplace_back(static_cast<int>(buf_size));
}
LOG(INFO) << "Allocate inout buffer";
success &= soc_interface_AllocateInOutNodeBuffers(
input_sizes.size(), input_sizes.data(), output_sizes.size(),
output_sizes.data());
return success;
}
bool HexagonControlWrapper::Finalize() { return soc_interface_Finalize(); }
bool HexagonControlWrapper::SetupGraph() {
// Copy graph transfer info so it can be modified to adapt to the hexnn library
GraphTransferInfo& graph_transfer_info =
graph_transferer_.GetMutableGraphTransferInfo();
// Overwrite op type of input nodes for hexagon
for (const GraphTransferGraphInputNodeInfo& graph_input :
graph_transfer_info.graph_input_node_info()) {
GraphTransferNodeInfo* node_info =
FindNodeInfo(graph_input.name(), &graph_transfer_info);
CHECK_NE(node_info, nullptr);
}
// Generate a new output node which is connected to graph output node
// TODO(satok): Support multiple output nodes
CHECK_EQ(graph_transfer_info.graph_output_node_info_size(), 1);
for (const GraphTransferGraphOutputNodeInfo& graph_output :
graph_transfer_info.graph_output_node_info()) {
const int new_output_node_id = graph_transfer_info.node_info_size() +
graph_transfer_info.const_node_info_size() +
2 /* offset for ids */;
// Register a new output node
GraphTransferNodeInfo& new_output_node_info =
*graph_transfer_info.add_node_info();
new_output_node_info.set_name(OUTPUT_OP_NAME);
new_output_node_info.set_node_id(new_output_node_id);
new_output_node_info.set_type_name(OUTPUT_OP_NAME);
new_output_node_info.set_soc_op_id(
HexagonOpsDefinitions::getInstance().GetOpIdFor(OUTPUT_OP_NAME, {}));
new_output_node_info.set_padding_id(0 /* PADDING_NA_ID */);
new_output_node_info.set_input_count(1);
new_output_node_info.set_output_count(0);
const TensorId tid = ParseTensorName(graph_output.name());
const string node_name(tid.first);
const int port = tid.second;
// Register node input for the new output node
const GraphTransferNodeInfo* node_info =
FindNodeInfo(node_name, &graph_transfer_info);
CHECK_NE(node_info, nullptr);
GraphTransferNodeInputInfo& node_input_info =
*graph_transfer_info.add_node_input_info();
node_input_info.set_node_id(new_output_node_id);
GraphTransferNodeInput& node_input = *node_input_info.add_node_input();
node_input.set_node_id(node_info->node_id());
node_input.set_output_port(port);
}
if (DBG_DUMP_VERIFICATION_STRING) {
GraphTransferer gt;
gt.SetSerializedGraphTransferInfo(graph_transfer_info.SerializeAsString());
gt.DumpVerificationStringOfNodeTransferParams();
}
int inputs_count = 0;
int outputs_count = 0;
for (const GraphTransferNodeInputInfo& input_params :
graph_transfer_info.node_input_info()) {
inputs_count += input_params.node_input_size();
}
for (const GraphTransferNodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) {
outputs_count += output_params.max_byte_size_size();
}
// Allocate memory for node inputs and node outputs
soc_interface_AllocateNodeInputAndNodeOutputArray(inputs_count,
outputs_count);
// Construct node input parameters
std::unordered_map<int, std::tuple<void*, int>> inputs_map;
for (const GraphTransferNodeInputInfo& input_params :
graph_transfer_info.node_input_info()) {
const int count = input_params.node_input_size();
CHECK(count <= MAX_IN_OUT_COUNT);
int node_ids[MAX_IN_OUT_COUNT];
int ports[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) {
const GraphTransferNodeInput& node_input = input_params.node_input(i);
node_ids[i] = node_input.node_id() + NODE_ID_OFFSET;
ports[i] = node_input.output_port();
}
void* inputs_ptr = soc_interface_SetOneNodeInputs(count, node_ids, ports);
const int node_id = input_params.node_id();
CHECK(inputs_map.count(node_id) == 0);
inputs_map.emplace(node_id, std::make_tuple(inputs_ptr, count));
}
// Construct node output parameters
std::unordered_map<int, std::tuple<void*, int>> outputs_map;
for (const GraphTransferNodeOutputInfo& output_params :
graph_transfer_info.node_output_info()) {
const int count = output_params.max_byte_size_size();
CHECK(count <= MAX_IN_OUT_COUNT);
int sizes[MAX_IN_OUT_COUNT];
for (int i = 0; i < count; ++i) {
const int size = output_params.max_byte_size(i);
sizes[i] = size;
}
void* outputs_ptr = soc_interface_SetOneNodeOutputs(count, sizes);
const int node_id = output_params.node_id();
CHECK(outputs_map.count(node_id) == 0);
outputs_map.emplace(node_id, std::make_tuple(outputs_ptr, count));
}
// Instantiate graph
soc_interface_InstantiateGraph();
// Initialize graph
// 1. Setup const nodes
for (const GraphTransferConstNodeInfo& params :
graph_transfer_info.const_node_info()) {
const int node_id = params.node_id();
// TODO(satok): Stop assuming shape size is 4.
CHECK(params.shape_size() == 4);
const int64 shape_0 = params.shape(0);
const int64 shape_1 = params.shape(1);
const int64 shape_2 = params.shape(2);
const int64 shape_3 = params.shape(3);
const int data_size = params.data().length();
CHECK(dummy_const_data_.count(node_id) == 0);
auto data = dummy_const_data_.emplace(
std::piecewise_construct, std::make_tuple(node_id), std::make_tuple());
CHECK(data.second);
data.first->second.resize(data_size + ALIGNMENT_BYTES - 1);
uint8* data_ptr = FindAlignedPointer(data.first->second.data());
std::memcpy(data_ptr, params.data().data(), data_size);
soc_interface_AppendConstNode(params.name().c_str(),
node_id + NODE_ID_OFFSET, shape_0, shape_1,
shape_2, shape_3, data_ptr, data_size);
}
// 2. Setup op nodes
for (const GraphTransferNodeInfo& params : graph_transfer_info.node_info()) {
const int node_id = params.node_id();
const int op_id = params.soc_op_id();
CHECK(inputs_map.count(node_id) == 1);
CHECK(outputs_map.count(node_id) <= 1);
// Only the output node has no outputs
const bool has_output = outputs_map.count(node_id) == 1;
const auto& input_ptr_and_count = inputs_map.at(node_id);
const void* input_ptr = std::get<0>(input_ptr_and_count);
const int input_count = std::get<1>(input_ptr_and_count);
void* output_ptr = nullptr;
int output_count = 0;
if (has_output) {
const auto& output_ptr_and_count = outputs_map.at(node_id);
output_ptr = std::get<0>(output_ptr_and_count);
output_count = std::get<1>(output_ptr_and_count);
// CHECK(output_count > 0);
}
int padding_id = -1;
if (params.padding_id() == 0) {
padding_id = 0;
} else if (params.padding_id() == Padding::SAME) {
padding_id = 1;
} else if (params.padding_id() == Padding::VALID) {
padding_id = 2;
} else {
LOG(FATAL);
}
soc_interface_AppendNode(params.name().c_str(), node_id + NODE_ID_OFFSET,
op_id, padding_id, input_ptr, input_count,
output_ptr, output_count);
}
LOG(INFO) << "Setup graph completed";
// 3. construct graph
return soc_interface_ConstructGraph();
// Keep the following comment to use dummy graph construction
// return soc_interface_setupDummyGraph(3 /* inception version */);
}
bool HexagonControlWrapper::ExecuteGraph() {
return soc_interface_ExecuteGraph();
}
bool HexagonControlWrapper::TeardownGraph() {
soc_interface_ReleaseNodeInputAndNodeOutputArray();
return soc_interface_TeardownGraph();
}
bool HexagonControlWrapper::FillInputNode(
const string& node_name,
const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>& shape,
const ConstByteArray bytes) {
const string tensor_name = AddPort(node_name);
CHECK(input_port_map_.count(tensor_name) > 0);
const int port = input_port_map_.at(tensor_name);
if (input_tensor_data_.count(port) <= 0) {
input_tensor_data_.emplace(port, std::vector<uint8>{});
}
std::vector<uint8>& input_tensor_data = input_tensor_data_.at(port);
// Hexagon only supports 32-bit dimensions
const int x = static_cast<int>(shape[0]);
const int y = static_cast<int>(shape[1]);
const int z = static_cast<int>(shape[2]);
const int d = static_cast<int>(shape[3]);
const uint64 byte_size = x * y * z * d * DataTypeSize(std::get<2>(bytes));
CHECK_EQ(byte_size, std::get<1>(bytes));
input_tensor_data.resize(byte_size + ALIGNMENT_BYTES);
uint8* data_ptr = FindAlignedPointer(input_tensor_data.data());
if (DBG_USE_DUMMY_INPUT) {
std::memset(data_ptr, 0, byte_size);
} else {
std::memcpy(data_ptr, std::get<0>(bytes), byte_size);
}
return soc_interface_FillInputNodeWithPort(port, x, y, z, d, data_ptr,
byte_size);
}
bool HexagonControlWrapper::ReadOutputNode(
const string& node_name, TensorAllocatorFunc tensor_allocator) {
CHECK_NE(execute_info_, nullptr);
TensorShape output_shape;
// TODO(satok): Switch shape corresponding to input shape
for (int i = 0; i < execute_info_->graph_output_node_name_size(); ++i) {
if (execute_info_->graph_output_node_name(i) == node_name) {
for (const TensorShapeProto::Dim& dim :
execute_info_->default_graph_output_tensor_shape(i).shape().dim()) {
output_shape.AddDim(dim.size());
}
break;
}
}
std::vector<ByteArray> outputs;
ReadOutputNode(node_name, &outputs);
CHECK_EQ(1, outputs.size());
ByteArray& output = outputs[0];
Tensor* output_tensor = tensor_allocator(output_shape);
CHECK(output_tensor->TotalBytes() >= std::get<1>(output))
<< output_tensor->TotalBytes() << ", " << std::get<1>(output);
TF_CHECK_OK(RemoteFusedGraphExecuteUtils::CopyByteArrayToTensor(
std::get<0>(output), std::get<1>(output), output_tensor));
return true;
}
bool HexagonControlWrapper::ReadOutputNode(
const string& node_name, std::vector<ByteArray>* const outputs) {
CHECK(outputs != nullptr);
ByteArray output;
const string tensor_name = AddPort(node_name);
CHECK(output_port_map_.count(tensor_name) > 0);
const int port = output_port_map_.at(tensor_name);
soc_interface_ReadOutputNodeWithPort(
port, &std::get<0>(output),
reinterpret_cast<uint64_t*>(&std::get<1>(output)));
// TODO: Accept all results
// std::get<2>(output) = DT_FLOAT;
outputs->emplace_back(output);
return true;
}
Status HexagonControlWrapper::FuseRemoteGraph(
const GraphDef& original_graph_def, const std::vector<string>& inputs,
const std::vector<string>& outputs, GraphDef* fused_graph_def) {
const std::unordered_set<string> fused_node_names =
RemoteFusedGraphExecuteUtils::BuildNodeMapFromOpsDefinitions(
original_graph_def, HexagonOpsDefinitions::getInstance());
// TODO(satok): We may want to place shape and type inside this function
// if they are not placed in the given graph.
TF_RETURN_IF_ERROR(RemoteFusedGraphExecuteUtils::FuseRemoteGraphByNodeNames(
original_graph_def, inputs, outputs, REMOTE_FUSED_GRAPH_NODE_NAME_PREFIX,
fused_node_names, REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
/*require_shape_type=*/true, fused_graph_def));
return Status::OK();
}
bool HexagonControlWrapper::FillInputNode(const string& node_name,
const Tensor& tensor) {
StringPiece tensor_data = tensor.tensor_data();
const ConstByteArray ba =
ConstByteArray(reinterpret_cast<const uint8*>(tensor_data.data()),
tensor_data.size(), tensor.dtype());
if (DBG_DUMP_INPUT_TENSOR_AS_FLOAT_DATA) {
LOG(INFO) << "Input tensor data: element size = " << tensor.NumElements()
<< ", byte syze = " << tensor.TotalBytes();
std::stringstream line;
for (int i = 0; i < tensor.NumElements(); ++i) {
line << tensor.flat<float>().data()[i] << ", ";
if ((i - 2) % 3 == 0 || i == tensor.NumElements() - 1) {
LOG(INFO) << "(" << ((i - 2) / 3) << ") " << line.str();
line.str("");
line.clear();
}
}
}
const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE> shape =
GraphTransferer::ToTensorShapeArray(tensor.shape());
FillInputNode(node_name, shape, ba);
return true;
}
bool HexagonControlWrapper::IsEnabled() const { return true; }
} // namespace tensorflow

View File

@ -1,91 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
#include <unordered_map>
#include <vector>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
/*
HexagonControlWrapper implements the IRemoteFusedGraphExecutor interface.
This class calls APIs on hexagon via the hexagon control binary.
TODO(satok): Add more documentation about the hexagon control binary.
*/
class HexagonControlWrapper final : public IRemoteFusedGraphExecutor {
public:
using ByteArray =
std::tuple<uint8* /* data */, uint64 /* size */, DataType /* type */>;
static constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTOR_NAME =
"build_hexagon_remote_fused_graph_executor";
HexagonControlWrapper() = default;
int GetVersion() final;
bool Init(const RemoteFusedGraphExecuteInfo& info) final;
bool Finalize() final;
bool SetupGraph() final;
bool ExecuteGraph() final;
bool TeardownGraph() final;
bool FillInputNode(const string& node_name, const Tensor& tensor) final;
bool ReadOutputNode(const string& node_name,
TensorAllocatorFunc tensor_allocator) final;
Status FuseRemoteGraph(const GraphDef& original_graph_def,
const std::vector<string>& inputs,
const std::vector<string>& outputs,
GraphDef* fused_graph_def) final;
bool IsEnabled() const final;
bool ReadOutputNode(const string& node_name, std::vector<ByteArray>* outputs);
private:
using ConstByteArray = std::tuple<const uint8* /* data */, uint64 /* size */,
DataType /* type */>;
bool FillInputNode(
const string& node_name,
const std::array<int64, GraphTransferer::SHAPE_ARRAY_SIZE>& shape,
const ConstByteArray bytes);
// CAVEAT: An offset is needed because the HVX library reserves some node ids.
static constexpr int NODE_ID_OFFSET = 0x10000;
static GraphTransferNodeInfo* FindNodeInfo(
const string& name, GraphTransferInfo* graph_transfer_info);
const RemoteFusedGraphExecuteInfo* execute_info_{};
GraphTransferer graph_transferer_{};
// Dummy float array for input node.
// TODO(satok): Use actual data passed by FillInputNode and remove
// std::vector<float> dummy_input_float_{};
std::unordered_map<int, std::vector<uint8>> input_tensor_data_{};
// Dummy byte array for const node.
// TODO(satok): Remove
std::unordered_map<int, std::vector<uint8>> dummy_const_data_{};
std::unordered_map<string, int> input_port_map_{};
std::unordered_map<string, int> output_port_map_{};
TF_DISALLOW_COPY_AND_ASSIGN(HexagonControlWrapper);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_CONTROL_WRAPPER_H_
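// Illustrative usage sketch (not part of the original header). Assuming a
// populated RemoteFusedGraphExecuteInfo `info` and an input Tensor `input` for
// node "Mul", the interface above is typically driven as follows (this mirrors
// RunInferenceByHexagonControlWrapper in the test further down in this diff):
//
//   HexagonControlWrapper wrapper;
//   wrapper.Init(info);                   // 1. initialize hexagon
//   wrapper.SetupGraph();                 // 2. set up the graph on hexagon
//   wrapper.FillInputNode("Mul", input);  // 3. fill the input node
//   wrapper.ExecuteGraph();               // 4. execute the graph
//   std::vector<HexagonControlWrapper::ByteArray> outputs;
//   wrapper.ReadOutputNode("softmax", &outputs);  // 5. read outputs
//   wrapper.TeardownGraph();              // 6. tear down the graph
//   wrapper.Finalize();                   // 7. finalize hexagon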

View File

@ -1,604 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
/* Before calling this test program, download a model as follows.
$ curl \
https://storage.googleapis.com/download.tensorflow.org/models/tensorflow_inception_v3_stripped_optimized_quantized.pb \
-o /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb
$ adb push /tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
/data/local/tmp
$ curl \
https://storage.googleapis.com/download.tensorflow.org/models/imagenet_comp_graph_label_strings.txt \
-o /tmp/imagenet_comp_graph_label_strings.txt
$ adb push /tmp/imagenet_comp_graph_label_strings.txt /data/local/tmp
*/
// define EIGEN_USE_THREADS to include quantization_utils.h
#define EIGEN_USE_THREADS
#include <memory>
#include "absl/base/casts.h"
#include "tensorflow/core/framework/graph_transfer_info.pb.h"
#include "tensorflow/core/framework/remote_fused_graph_execute_info.pb.h"
#include "tensorflow/core/framework/tensor_shape.pb.h"
#include "tensorflow/core/framework/tensor_testutil.h"
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/graph_transferer.h"
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/kernels/quantization_utils.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/profile_utils/clock_cycle_profiler.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/public/session.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
using ByteArray = HexagonControlWrapper::ByteArray;
constexpr const char* const IMAGE_FILENAME = "/data/local/tmp/img_299x299.bmp";
constexpr const char* const MODEL_FILENAME =
"/data/local/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb";
constexpr const char* const MODEL_WITH_QUANTIZED_INPUT_FILENAME =
"/data/local/tmp/"
"tensorflow_inception_v3_stripped_optimized_quantized_with_quantized_input."
"pb";
constexpr const char* const FUSED_MODEL_FILENAME =
"/data/local/tmp/"
"tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb";
constexpr const char* const REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME =
"remote_fused_graph_execute_node";
constexpr bool USE_SHAPE_INFERENCE = false;
const bool DBG_DUMP_FLOAT_DATA = false;
const int WIDTH = 299;
const int HEIGHT = 299;
const int DEPTH = 3;
const int EXPECTED_FIRST_RESULT_ID = 59;
const int EXECUTION_REPEAT_COUNT = 10;
static void CheckHexagonControllerVersion() {
HexagonControlWrapper hexagon_control_wrapper;
const int version = hexagon_control_wrapper.GetVersion();
ASSERT_GE(version, 1);
LOG(INFO) << "Hexagon controller version is " << version;
}
static void DumpTop10Results(const int byte_size,
const float* const float_array) {
const int element_count = byte_size / sizeof(float);
const string label_filename =
"/data/local/tmp/imagenet_comp_graph_label_strings.txt";
string label_str;
TF_CHECK_OK(ReadFileToString(Env::Default(), label_filename, &label_str));
std::vector<string> labels = str_util::Split(label_str, '\n');
GraphTransferUtils::DumpTopNFloatResults(
float_array, labels.data(),
std::min(element_count, static_cast<int>(labels.size())),
10 /* show top_n results */);
}
static void DumpTop10Results(const std::vector<ByteArray>& outputs) {
CHECK(outputs.size() == 1);
const int byte_size = std::get<1>(outputs.at(0));
const float* float_array =
reinterpret_cast<float*>(std::get<0>(outputs.at(0)));
DumpTop10Results(byte_size, float_array);
}
static void CheckFirstResult(const std::vector<ByteArray>& outputs,
const int expected_first_id) {
EXPECT_GE(outputs.size(), 1);
const int byte_size = std::get<1>(outputs.at(0));
const int element_count = byte_size / sizeof(float);
const float* float_array =
reinterpret_cast<float*>(std::get<0>(outputs.at(0)));
EXPECT_GE(element_count, 1);
std::vector<string> labels(element_count);
std::priority_queue<std::tuple<float, int, string>> queue =
GraphTransferUtils::GetTopNFloatResults(float_array, labels.data(),
element_count);
const std::tuple<float, int, string>& entry = queue.top();
EXPECT_EQ(expected_first_id, std::get<1>(entry));
}
static void LoadImage(std::vector<float>* img_floats_ptr) {
CHECK(img_floats_ptr != nullptr);
std::vector<float>& img_floats = *img_floats_ptr;
// Read the data from the bitmap file into memory
string bmp;
TF_CHECK_OK(ReadFileToString(Env::Default(), IMAGE_FILENAME, &bmp));
const int fsize = bmp.size();
LOG(INFO) << "Read " << IMAGE_FILENAME << ", size = " << fsize << "bytes";
const int64 pixel_count = WIDTH * HEIGHT * DEPTH;
CHECK(fsize >= 22 /* pos of height */ + sizeof(int));
CHECK(bmp.data() != nullptr);
uint8* const img_bytes = absl::bit_cast<uint8*>(bmp.data());
const int header_size = *(reinterpret_cast<int*>(img_bytes + 10));
LOG(INFO) << "header size = " << header_size;
const int size = *(reinterpret_cast<int*>(img_bytes + 14));
LOG(INFO) << "image size = " << size;
const int width = *(reinterpret_cast<int*>(img_bytes + 18));
LOG(INFO) << "width = " << width;
const int height = *(reinterpret_cast<int*>(img_bytes + 22));
LOG(INFO) << "height = " << height;
CHECK(fsize >= (WIDTH + 1) * WIDTH * 3 + header_size);
uint8* const bmp_pixels = &img_bytes[header_size];
img_floats.resize(pixel_count);
int src_pixel_index = 0;
CHECK(pixel_count % 3 == 0);
for (int i = 0; i < pixel_count / 3; ++i) {
const int src_pos = 3 * src_pixel_index;
const int dst_pos = 3 * i;
++src_pixel_index;
CHECK(src_pos + 2 + header_size < fsize);
CHECK(dst_pos + 2 < pixel_count);
// Convert (B, G, R) in bitmap to (R, G, B)
img_floats[dst_pos] =
(static_cast<float>(bmp_pixels[src_pos + 2]) - 128.0f) / 128.0f;
img_floats[dst_pos + 1] =
(static_cast<float>(bmp_pixels[src_pos + 1]) - 128.0f) / 128.0f;
img_floats[dst_pos + 2] =
(static_cast<float>(bmp_pixels[src_pos]) - 128.0f) / 128.0f;
if (DBG_DUMP_FLOAT_DATA) {
LOG(INFO) << i << " (" << img_floats[dst_pos] << ", "
<< img_floats[dst_pos + 1] << ", " << img_floats[dst_pos + 2]
<< ") (" << static_cast<int>(bmp_pixels[src_pos + 2]) << ", "
<< static_cast<int>(bmp_pixels[src_pos + 1]) << ", "
<< static_cast<int>(bmp_pixels[src_pos]) << ")";
}
if (src_pixel_index % (WIDTH + 1) == (WIDTH - 1)) {
// skip bmp padding
++src_pixel_index;
}
}
}
static void QuantizeImage(const std::vector<float>& float_vec,
std::vector<quint8>* quint8_vec) {
quint8_vec->resize(float_vec.size());
for (int i = 0; i < float_vec.size(); ++i) {
quint8_vec->at(i) = FloatToQuantized<quint8>(float_vec[i], -1.0f, 1.0f);
}
}
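// Note (illustrative, not in the original file): FloatToQuantized<quint8> from
// quantization_utils maps range_min to the lowest quantized value and
// range_max to the highest, so with the [-1.0f, 1.0f] range used above,
// -1.0f -> 0, 1.0f -> 255, and 0.0f lands at roughly 128.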
static Tensor BuildImageTensor(const std::vector<float>& img_floats) {
LOG(INFO) << "Loading image finished.";
Tensor img_tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH});
CHECK_EQ(WIDTH * HEIGHT * DEPTH, img_floats.size());
CHECK_EQ(img_tensor.TotalBytes(), img_floats.size() * sizeof(float));
LOG(INFO) << "Copy data to tensor.";
std::memcpy(img_tensor.flat<float>().data(), img_floats.data(),
img_tensor.TotalBytes());
return img_tensor;
}
static Tensor BuildQuantizedImageTensor(
const std::vector<quint8>& quantized_img) {
LOG(INFO) << "Loading image finished.";
Tensor img_tensor(DT_QUINT8, {1, WIDTH, HEIGHT, DEPTH});
CHECK_EQ(WIDTH * HEIGHT * DEPTH, quantized_img.size());
CHECK_EQ(img_tensor.TotalBytes(), quantized_img.size() * sizeof(quint8));
LOG(INFO) << "Copy data to tensor.";
std::memcpy(img_tensor.flat<quint8>().data(), quantized_img.data(),
img_tensor.TotalBytes());
return img_tensor;
}
/* static */ RemoteFusedGraphExecuteInfo
BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
const GraphTransferInfo& graph_transfer_info) {
RemoteFusedGraphExecuteInfo execute_info;
execute_info.set_executor_name("build_hexagon_remote_fused_graph_executor");
for (const GraphTransferGraphInputNodeInfo& input :
graph_transfer_info.graph_input_node_info()) {
execute_info.add_graph_input_node_name(input.name());
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
*execute_info.add_default_graph_input_tensor_shape();
tensor_shape_type.set_dtype(input.dtype());
TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
for (const int64 dim : input.shape()) {
tensor_shape_proto.add_dim()->set_size(dim);
}
}
for (const GraphTransferGraphOutputNodeInfo& output :
graph_transfer_info.graph_output_node_info()) {
execute_info.add_graph_output_node_name(output.name());
RemoteFusedGraphExecuteInfo::TensorShapeTypeProto& tensor_shape_type =
*execute_info.add_default_graph_output_tensor_shape();
tensor_shape_type.set_dtype(output.dtype());
TensorShapeProto& tensor_shape_proto = *tensor_shape_type.mutable_shape();
for (const int64 dim : output.shape()) {
tensor_shape_proto.add_dim()->set_size(dim);
}
}
execute_info.set_serialized_executor_parameters(
graph_transfer_info.SerializeAsString());
return execute_info;
}
static void RunInferenceByHexagonControlWrapper(const GraphTransferer& gt,
const Tensor& img_tensor) {
const RemoteFusedGraphExecuteInfo execute_info =
BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(
gt.GetGraphTransferInfo());
HexagonControlWrapper hexagon_control_wrapper;
// 1. Initialize hexagon
hexagon_control_wrapper.Init(execute_info);
// 2. Setup graph in hexagon
hexagon_control_wrapper.SetupGraph();
// 3. Fill input node's output
hexagon_control_wrapper.FillInputNode("Mul", img_tensor);
// 4. Execute graph
const int64 start_time_us = Env::Default()->NowMicros();
for (int i = 0; i < EXECUTION_REPEAT_COUNT; ++i) {
hexagon_control_wrapper.ExecuteGraph();
}
const int64 end_time_us = Env::Default()->NowMicros();
// 5-1. Read output node's outputs
std::vector<ByteArray> outputs;
hexagon_control_wrapper.ReadOutputNode("softmax", &outputs);
// 5-2. Dump results
DumpTop10Results(outputs);
CheckFirstResult(outputs, EXPECTED_FIRST_RESULT_ID);
LOG(INFO) << "Average execution time = "
<< (end_time_us - start_time_us) / EXECUTION_REPEAT_COUNT << "us";
// 6. Teardown graph in hexagon
hexagon_control_wrapper.TeardownGraph();
// 7. Finalize hexagon
hexagon_control_wrapper.Finalize();
}
static void RunFusedGraph(const GraphDef& fused_graph_def) {
// Setup input tensor
std::vector<float> img_floats;
LoadImage(&img_floats);
LOG(INFO) << "Ioading image finished.";
const Tensor img_tensor = BuildImageTensor(img_floats);
// Setup session
std::vector<Tensor> output_tensors;
SessionOptions session_options;
session_options.env = Env::Default();
std::unique_ptr<Session> session =
std::unique_ptr<Session>(NewSession(session_options));
TF_ASSERT_OK(session->Create(fused_graph_def));
// Setup session arguments
RunOptions run_options;
run_options.set_trace_level(RunOptions::FULL_TRACE);
RunMetadata run_metadata;
std::vector<std::pair<string, tensorflow::Tensor>> input_tensors;
input_tensors.emplace_back("Mul", img_tensor);
std::vector<string> output_node_names;
output_node_names.emplace_back(REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME);
LOG(INFO) << "Run graph";
// Run inference with all node as output
TF_ASSERT_OK(session->Run(run_options, input_tensors, output_node_names, {},
&output_tensors, &run_metadata));
ASSERT_EQ(1, output_tensors.size());
const Tensor& output_tensor = output_tensors.at(0);
LOG(INFO) << "Output byte size = " << output_tensor.TotalBytes();
LOG(INFO) << "Output shape = " << output_tensor.shape().DebugString();
DumpTop10Results(
output_tensor.TotalBytes(),
reinterpret_cast<const float*>(output_tensor.flat<float>().data()));
}
static void CompareGraphTransferInfo(const GraphTransferInfo& gfi0,
const GraphTransferInfo& gfi1) {
LOG(INFO) << "(1) node count: " << gfi1.node_info_size() << ", "
<< gfi1.const_node_info_size();
// 1. check node_info
ASSERT_EQ(gfi0.node_info_size(), gfi1.node_info_size());
for (int i = 0; i < gfi0.node_info_size(); ++i) {
const GraphTransferNodeInfo& ni0 = gfi0.node_info(i);
const GraphTransferNodeInfo& ni1 = gfi1.node_info(i);
EXPECT_EQ(ni0.DebugString(), ni1.DebugString());
EXPECT_EQ(ni0.ByteSizeLong(), ni1.ByteSizeLong());
}
// 2. check const_node_info
ASSERT_EQ(gfi0.const_node_info_size(), gfi1.const_node_info_size());
for (int i = 0; i < gfi0.const_node_info_size(); ++i) {
const GraphTransferConstNodeInfo& cni0 = gfi0.const_node_info(i);
const GraphTransferConstNodeInfo& cni1 = gfi1.const_node_info(i);
ASSERT_EQ(cni0.shape_size(), cni1.shape_size());
for (int j = 0; j < cni0.shape_size(); ++j) {
EXPECT_EQ(cni0.shape(j), cni1.shape(j));
}
EXPECT_EQ(cni0.ByteSizeLong(), cni1.ByteSizeLong());
EXPECT_EQ(cni0.DebugString(), cni1.DebugString());
}
// 3. check node_input_info
ASSERT_EQ(gfi0.node_input_info_size(), gfi1.node_input_info_size());
for (int i = 0; i < gfi0.node_input_info_size(); ++i) {
const GraphTransferNodeInputInfo& nii0 = gfi0.node_input_info(i);
const GraphTransferNodeInputInfo& nii1 = gfi1.node_input_info(i);
EXPECT_EQ(nii0.ByteSizeLong(), nii1.ByteSizeLong());
EXPECT_EQ(nii0.DebugString(), nii1.DebugString());
}
// 4. check node_output_info
ASSERT_EQ(gfi0.node_output_info_size(), gfi1.node_output_info_size());
for (int i = 0; i < gfi0.node_output_info_size(); ++i) {
const GraphTransferNodeOutputInfo& noi0 = gfi0.node_output_info(i);
const GraphTransferNodeOutputInfo& noi1 = gfi1.node_output_info(i);
ASSERT_EQ(noi0.max_byte_size_size(), noi1.max_byte_size_size());
for (int j = 0; j < noi0.max_byte_size_size(); ++j) {
EXPECT_EQ(noi0.max_byte_size(j), noi1.max_byte_size(j));
}
EXPECT_EQ(noi0.ByteSizeLong(), noi1.ByteSizeLong());
EXPECT_EQ(noi0.DebugString(), noi1.DebugString());
}
// 5. check graph_input_node_info
ASSERT_EQ(gfi0.graph_input_node_info_size(),
gfi1.graph_input_node_info_size());
for (int i = 0; i < gfi0.graph_input_node_info_size(); ++i) {
const GraphTransferGraphInputNodeInfo& gini0 =
gfi0.graph_input_node_info(i);
const GraphTransferGraphInputNodeInfo& gini1 =
gfi1.graph_input_node_info(i);
EXPECT_EQ(gini0.ByteSizeLong(), gini1.ByteSizeLong());
EXPECT_EQ(gini0.DebugString(), gini1.DebugString());
}
// 6. check graph_output_node_info
ASSERT_EQ(gfi0.graph_output_node_info_size(),
gfi1.graph_output_node_info_size());
for (int i = 0; i < gfi0.graph_output_node_info_size(); ++i) {
const GraphTransferGraphOutputNodeInfo& goni0 =
gfi0.graph_output_node_info(i);
const GraphTransferGraphOutputNodeInfo& goni1 =
gfi1.graph_output_node_info(i);
EXPECT_EQ(goni0.ByteSizeLong(), goni1.ByteSizeLong());
EXPECT_EQ(goni0.DebugString(), goni1.DebugString());
}
}
// CAVEAT: This test only runs when you specify the hexagon library using the
// makefile.
// CAVEAT: This test is disabled by default because hexagon can keep only two
// inception graphs in memory, which are allocated by the other two tests.
// Memory for these graphs is currently not released until the process is
// killed.
// TODO(satok): Figure out how to release memory on hexagon without process
// termination.
#ifdef USE_HEXAGON_LIBS
TEST(GraphTransferer,
DISABLED_RunInceptionV3OnHexagonExampleWithHexagonWrapper) {
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
CheckHexagonControllerVersion();
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
std::vector<string> output_node_names = {"softmax"};
GraphTransferer gt;
gt.EnableStrictCheckMode(false);
profile_utils::CpuUtils::EnableClockCycleProfiling();
ClockCycleProfiler prof;
prof.Start();
Status status = gt.LoadGraphFromProtoFile(
*ops_definitions, MODEL_FILENAME, inputs, output_node_names,
false, // is_text_proto
false, // shape_inference_for_unknown_shape
true // dry_run_for_unknown_shape
);
ASSERT_TRUE(status.ok()) << status;
prof.Stop();
prof.DumpStatistics("LoadGraphFromProtoFile");
std::vector<float> img_floats;
LoadImage(&img_floats);
const Tensor img_tensor = BuildImageTensor(img_floats);
RunInferenceByHexagonControlWrapper(gt, img_tensor);
}
TEST(GraphTransferer,
DISABLED_RunInceptionV3OnHexagonExampleWithHexagonWrapperQuantizedInput) {
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller "
<< "with quantized input";
CheckHexagonControllerVersion();
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_QUINT8, {1, WIDTH, HEIGHT, DEPTH}));
std::vector<string> output_node_names = {"softmax"};
GraphTransferer gt;
gt.EnableStrictCheckMode(false);
profile_utils::CpuUtils::EnableClockCycleProfiling();
ClockCycleProfiler prof;
prof.Start();
Status status = gt.LoadGraphFromProtoFile(
*ops_definitions, MODEL_WITH_QUANTIZED_INPUT_FILENAME, inputs,
output_node_names,
/*is_text_proto=*/false,
/*shape_inference_for_unknown_shape=*/false,
/*dry_run_for_unknown_shape=*/true);
ASSERT_TRUE(status.ok()) << status;
prof.Stop();
prof.DumpStatistics("LoadGraphFromProtoFile");
std::vector<float> img_floats;
LoadImage(&img_floats);
std::vector<quint8> quantized_img;
QuantizeImage(img_floats, &quantized_img);
const Tensor img_tensor = BuildQuantizedImageTensor(quantized_img);
RunInferenceByHexagonControlWrapper(gt, img_tensor);
}
TEST(GraphTransferer,
DISABLED_RunInceptionV3OnHexagonExampleWithHexagonWrapperShapeInference) {
LOG(INFO) << "Run inception v3 on hexagon with hexagon controller";
CheckHexagonControllerVersion();
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
std::vector<string> output_node_names = {"softmax"};
GraphTransferer gt;
gt.EnableStrictCheckMode(false);
profile_utils::CpuUtils::EnableClockCycleProfiling();
ClockCycleProfiler prof;
prof.Start();
Status status = gt.LoadGraphFromProtoFile(
*ops_definitions, MODEL_FILENAME, inputs, output_node_names,
false, // is_text_proto
true, // shape_inference_for_unknown_shape
false // dry_run_for_unknown_shape
);
ASSERT_TRUE(status.ok()) << status;
prof.Stop();
prof.DumpStatistics("LoadGraphFromProtoFile");
std::vector<float> img_floats;
LoadImage(&img_floats);
const Tensor img_tensor = BuildImageTensor(img_floats);
RunInferenceByHexagonControlWrapper(gt, img_tensor);
}
TEST(GraphTransferer, RunInceptionV3OnHexagonExampleWithTfRuntime) {
LOG(INFO) << "Fuse and run inception v3 on hexagon with tf runtime";
CheckHexagonControllerVersion();
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
std::vector<string> outputs = {"softmax"};
std::vector<float> img_floats;
LoadImage(&img_floats);
LOG(INFO) << "Ioading image finished.";
GraphDef graph_def;
Status status = ReadBinaryProto(Env::Default(), MODEL_FILENAME, &graph_def);
ASSERT_TRUE(status.ok());
LOG(INFO) << "Build fused graph";
GraphDef fused_graph_def = GraphTransferUtils::BuildFusedGraphDef(
HexagonOpsDefinitions::getInstance(),
REMOTE_FUSED_GRAPH_EXECUTE_NODE_NAME, inputs, outputs, &graph_def);
RunFusedGraph(fused_graph_def);
}
TEST(GraphTransferer, DISABLED_RunInceptionV3OnHexagonExampleWithFusedGraph) {
LOG(INFO) << "Run inception v3 with fused graph";
CheckHexagonControllerVersion();
GraphDef fused_graph_def;
Status status =
ReadBinaryProto(Env::Default(), FUSED_MODEL_FILENAME, &fused_graph_def);
RunFusedGraph(fused_graph_def);
}
TEST(GraphTransferer, DISABLED_CheckShapeInferencePerformance) {
CheckHexagonControllerVersion();
profile_utils::CpuUtils::EnableClockCycleProfiling();
const IRemoteFusedGraphOpsDefinitions* ops_definitions =
&HexagonOpsDefinitions::getInstance();
std::vector<std::pair<string, Tensor>> inputs;
inputs.emplace_back("Mul", Tensor(DT_FLOAT, {1, WIDTH, HEIGHT, DEPTH}));
std::vector<string> output_node_names = {"softmax"};
GraphTransferer gt0;
gt0.EnableStrictCheckMode(false);
ClockCycleProfiler prof0;
prof0.Start();
Status status = gt0.LoadGraphFromProtoFile(
*ops_definitions, MODEL_FILENAME, inputs, output_node_names,
false, // is_text_proto
false, // shape_inference_for_unknown_shape
true // dry_run_for_unknown_shape
);
const GraphTransferInfo& gfi0 = gt0.GetGraphTransferInfo();
ASSERT_TRUE(status.ok());
prof0.Stop();
prof0.DumpStatistics("Estimate shape by dryrun");
LOG(INFO) << "(0) node count: " << gfi0.node_info_size() << ", "
<< gfi0.const_node_info_size();
GraphTransferer gt1;
gt1.EnableStrictCheckMode(true);
ClockCycleProfiler prof1;
prof1.Start();
status = gt1.LoadGraphFromProtoFile(
*ops_definitions, MODEL_FILENAME, inputs, output_node_names,
false, // is_text_proto
true, // shape_inference_for_unknown_shape
false // dry_run_for_unknown_shape
);
const GraphTransferInfo& gfi1 = gt1.GetGraphTransferInfo();
ASSERT_TRUE(status.ok());
prof1.Stop();
prof1.DumpStatistics("Estiame shape by shape inference");
CompareGraphTransferInfo(gfi0, gfi1);
const RemoteFusedGraphExecuteInfo ei0 =
BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi0);
const RemoteFusedGraphExecuteInfo ei1 =
BuildRemoteFusedGraphExecuteInfoWithGraphTransferInfo(gfi1);
GraphTransferInfo rgfi0;
rgfi0.ParseFromString(ei0.serialized_executor_parameters());
GraphTransferInfo rgfi1;
rgfi1.ParseFromString(ei1.serialized_executor_parameters());
CompareGraphTransferInfo(rgfi0, rgfi1);
CompareGraphTransferInfo(gfi0, rgfi0);
CompareGraphTransferInfo(gfi1, rgfi1);
}
#endif
} // namespace tensorflow

View File

@ -1,408 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/core/framework/types.h"
// CAVEAT: Uncomment the following macro if you want to use experimental
// hexagon ops.
//#define ENABLE_EXPERIMENTAL_HEXNN_OPS
namespace tensorflow {
// Names of ops supported internally by HVX.
// TODO(satok): Remove this map once the hexnn lib supports an API to retrieve
// an op id from an op name and data type.
enum class HexagonOpsDefinitions::SupportedOpType {
INPUT,
OUTPUT,
NOP,
OP_CONST, /* OP_ is required to avoid compilation error on windows */
CHECK,
CLOSE_FLOAT32,
CLOSE_QINT8,
CLOSE_Q_QINT8,
CLOSE_INT32,
CLOSE_QINT32,
PPRINT_8,
PPRINT_32,
PPRINT_FLOAT,
PREFREE,
FLATTEN,
#ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS
// With Reference
QUANTIZEDCONV2D_8X8TO32,
QUANTIZEDCONV2D_8X8TO32_REF,
QUANTIZEDMATMUL_8X8TO32,
QUANTIZEDMATMUL_8X8TO32_REF,
QUANTIZEDOWNANDSHRINKRANGE_32TO8,
QUANTIZEDOWNANDSHRINKRANGE_32TO8_REF,
QUANTIZEDRELU_8,
QUANTIZEDRELU_8_REF,
QUANTIZEDRELUX_8,
QUANTIZEDRELUX_8_REF,
QUANTIZEDMAXPOOL_8,
QUANTIZEDMAXPOOL_8_REF,
QUANTIZEDAVGPOOL_8,
QUANTIZEDAVGPOOL_8_REF,
QUANTIZEDCONCAT_8,
QUANTIZEDCONCAT_8_REF,
QUANTIZEDBIASADD_8P8TO32,
QUANTIZEDBIASADD_8P8TO32_REF,
MIN_F,
MIN_F_REF,
MAX_F,
MAX_F_REF,
QUANTIZE,
QUANTIZE_REF,
DEQUANTIZE,
DEQUANTIZE_REF,
SUPERNODE_8X8P8TO8,
SUPERNODE_8X8P8TO8_REF,
QUANTIZEDFLATTEN,
SOFTMAX_F,
CONV2D_F,
MATMUL_F,
RELU_F,
RELUX_F,
AVGPOOL_F,
MAXPOOL_F,
CONCAT_F,
BIASADD_F,
LRN_F,
VARIABLE,
ASSIGN,
RESHAPE,
QUANTIZED_RESHAPE,
TANH_F,
SIGMOID_F,
SLICE_8,
SLICE_F,
QUANTIZED_SLICE_8,
ADD_F,
MUL_F,
MINIMUM_F,
MAXIMUM_F,
REQUANTIZE_32_TO_8,
REQUANTIZE_32_TO_8_REF,
REQUANTIZATION_RANGE_32,
REQUANTIZATION_RANGE_32_REF,
NEG_F,
SUB_F,
ADD_N_F,
RANGE_INT32,
RANK_INT32,
TRANSPOSE_INT32,
TRANSPOSE_F,
INSTANCE_NORM_F,
QUANTIZED_INSTANCENORM_8,
QUANTIZED_INSTANCENORM_8_REF,
SUB_INT32,
ADD_INT32,
SPLIT_F,
DEQUANTIZE_QINT32_F,
PRELU_F,
QUANTIZED_PRELU_8,
SUM_F,
PROD_F,
MUL_INT32,
LOGICAL_AND_INT32,
LOGICALOR_INT32,
LOGICAL_XOR_INT32,
SPAPE_INT32,
PACK_INT32,
MIRROR_PAD_F,
RESIZE_NEAREST_NEIGHBOR_F,
STRIDED_SLICE_INT32,
STRIDED_SLICE_F,
EXPAND_DIMS_INT32,
EXPAND_DIMS_F,
LOG_SOFTMAX_F,
SPLIT_INT32,
QUANTIZED_SPLIT_8,
DECONV_F,
QUANTIZED_DECONV_8X8TO32,
QUANTIZED_DECONV_8X8TO32_REF,
QUANTIZED_MUL_8x8to32,
QUANTIZED_MUL_8x8to32_REF,
QUANTIZED_ADD_8p8to32,
QUANTIZED_ADD_8p8to32_REF,
QUANTIZED_SIGMOID_8,
QUANTIZED_SIGMOID_8_REF,
QUANTIZED_TANH_8,
QUANTIZED_TANH_8_REF,
QUANTIZED_SOFTMAX_8,
QUANTIZED_SOFTMAX_8_REF,
QUANTIZED_LRN_8,
QUANTIZED_LRN_8_REF,
QUANTIZED_PAD2D_FRAME_8P,
QUANTIZED_PAD2D_FRAME_8P_REF,
QUANTIZED_SUB_8P8TO32,
QUANTIZED_SUB_8P8TO32_REF,
QUANTIZED_MAXIMUM_8,
QUANTIZED_MAXIMUM_8_REF,
QUANTIZED_MINIMUM_8,
QUANTIZED_MINIMUM_8_REF,
PAD_F,
SPACE_TO_BATCH_ND_F,
BATCH_TO_SPACE_ND_F,
RESIZE_BILINEAR_F,
CONCAT_V2_F,
#else
// With Reference
QUANTIZEDCONV2D_8X8TO32,
QUANTIZEDCONV2D_8X8TO32_REF,
QUANTIZEDMATMUL_8X8TO32,
QUANTIZEDMATMUL_8X8TO32_REF,
QUANTIZEDOWNANDSHRINKRANGE_32TO8,
QUANTIZEDOWNANDSHRINKRANGE_32TO8_REF,
QUANTIZEDRELU_8,
QUANTIZEDRELU_8_REF,
QUANTIZEDRELUX_8,
QUANTIZEDRELUX_8_REF,
QUANTIZEDSIGMOID_8,
QUANTIZEDSIGMOID_8_REF,
QUANTIZEDTANH_8,
QUANTIZEDTANH_8_REF,
QUANTIZEDMAXPOOL_8,
QUANTIZEDMAXPOOL_8_REF,
QUANTIZEDAVGPOOL_8,
QUANTIZEDAVGPOOL_8_REF,
QUANTIZEDCONCAT_8,
QUANTIZEDCONCAT_8_REF,
QUANTIZEDBIASADD_8P8TO32,
QUANTIZEDBIASADD_8P8TO32_REF,
QUANTIZEDSOFTMAX_8,
QUANTIZEDSOFTMAX_8_REF,
QUANTIZEDLRN_8,
QUANTIZEDLRN_8_REF,
MIN_F,
MIN_F_REF,
MAX_F,
MAX_F_REF,
QUANTIZE,
QUANTIZE_REF,
DEQUANTIZE,
DEQUANTIZE_REF,
SUPERNODE_8X8P8TO8,
SUPERNODE_8X8P8TO8_REF,
QUANTIZEDFLATTEN,
SOFTMAX_F,
CONV2D_F,
MATMUL_F,
RELU_F,
RELUX_F,
AVGPOOL_F,
MAXPOOL_F,
CONCAT_F,
BIASADD_F,
LRN_F,
VARIABLE,
ASSIGN,
RESHAPE,
QUANTIZED_RESHAPE,
TANH_F,
SIGMOID_F,
SLICE_8,
SLICE_F,
QUANTIZED_SLICE_8,
ADD_F,
MUL_F,
MINIMUM_F,
MAXIMUM_F,
REQUANTIZE_32_TO_8,
REQUANTIZE_32_TO_8_REF,
REQUANTIZATION_RANGE_32,
REQUANTIZATION_RANGE_32_REF,
NEG_F,
SUB_F,
ADD_N_F,
RANGE_INT32,
RANK_INT32,
TRANSPOSE_INT32,
TRANSPOSE_F,
INSTANCE_NORM_F,
QUANTIZED_INSTANCENORM_8,
QUANTIZED_INSTANCENORM_8_REF,
SUB_INT32,
ADD_INT32,
SPLIT_F,
DEQUANTIZE_QINT32_F,
PRELU_F,
QUANTIZED_PRELU_8,
SUM_F,
PROD_F,
MUL_INT32,
LOGICAL_AND_INT32,
LOGICALOR_INT32,
LOGICAL_XOR_INT32,
SPAPE_INT32,
PACK_INT32,
MIRROR_PAD_F,
RESIZE_NEAREST_NEIGHBOR_F,
STRIDED_SLICE_INT32,
STRIDED_SLICE_F,
EXPAND_DIMS_INT32,
EXPAND_DIMS_F,
LOG_SOFTMAX_F,
SPLIT_INT32,
QUANTIZED_SPLIT_8,
DECONV_F,
QUANTIZED_DECONV_8X8TO32,
QUANTIZED_DECONV_8X8TO32_REF,
#endif
SUPPORTED_OP_TYPE_COUNT // TERMINATOR. DO NOT REMOVE
};
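// Note (illustrative, not in the original file): GetOpIdFor below returns
// static_cast<int>(SupportedOpType), i.e. an enumerator's ordinal in this list
// is the soc op id handed to the hexagon library, which is why the order of
// entries must be kept stable (see the CAVEAT in BuildOpNameToSocOpTypeMap).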
/* static */ void HexagonOpsDefinitions::EmplaceOpType(
const string& op_type, const DataTypeVector& dt_vec,
const SupportedOpType supported_op_type,
std::unordered_map<string, std::vector<DataTypeToOp>>* map) {
if (map->count(op_type) <= 0) {
map->emplace(op_type, std::vector<DataTypeToOp>());
}
map->at(op_type).emplace_back(
std::forward_as_tuple(dt_vec, supported_op_type));
}
/* static */ std::unordered_map<
string, std::vector<HexagonOpsDefinitions::DataTypeToOp>>
HexagonOpsDefinitions::BuildOpNameToSocOpTypeMap() {
std::unordered_map<string, std::vector<DataTypeToOp>> op_map;
// Custom Op name
EmplaceOpType("INPUT", {}, SupportedOpType::INPUT, &op_map);
EmplaceOpType("OUTPUT", {}, SupportedOpType::OUTPUT, &op_map);
EmplaceOpType("NoOp", {}, SupportedOpType::NOP, &op_map);
// Special op type for hexagon
EmplaceOpType("FLATTEN", {}, SupportedOpType::FLATTEN, &op_map);
// Tensorflow op name
// CAVEAT: Keep order of SupportedOpType
EmplaceOpType("Identity", {}, SupportedOpType::NOP, &op_map);
EmplaceOpType("Placeholder", {}, SupportedOpType::NOP, &op_map);
EmplaceOpType("Const", {}, SupportedOpType::OP_CONST, &op_map);
EmplaceOpType("QuantizedConv2D", {}, SupportedOpType::QUANTIZEDCONV2D_8X8TO32,
&op_map);
EmplaceOpType("QuantizedMatMul", {}, SupportedOpType::QUANTIZEDMATMUL_8X8TO32,
&op_map);
EmplaceOpType("QuantizeDownAndShrinkRange", {},
SupportedOpType::QUANTIZEDOWNANDSHRINKRANGE_32TO8, &op_map);
EmplaceOpType("QuantizedRelu", {}, SupportedOpType::QUANTIZEDRELU_8, &op_map);
EmplaceOpType("QuantizedReluX", {}, SupportedOpType::QUANTIZEDRELUX_8,
&op_map);
EmplaceOpType("QuantizedMaxPool", {}, SupportedOpType::QUANTIZEDMAXPOOL_8,
&op_map);
EmplaceOpType("QuantizedAvgPool", {}, SupportedOpType::QUANTIZEDAVGPOOL_8,
&op_map);
EmplaceOpType("QuantizedConcat", {}, SupportedOpType::QUANTIZEDCONCAT_8,
&op_map);
EmplaceOpType("QuantizedBiasAdd", {},
SupportedOpType::QUANTIZEDBIASADD_8P8TO32, &op_map);
EmplaceOpType("Min", {}, SupportedOpType::MIN_F, &op_map);
EmplaceOpType("Max", {}, SupportedOpType::MAX_F, &op_map);
EmplaceOpType("QuantizeV2", {}, SupportedOpType::QUANTIZE, &op_map);
EmplaceOpType("Dequantize", {}, SupportedOpType::DEQUANTIZE, &op_map);
EmplaceOpType("Softmax", {}, SupportedOpType::SOFTMAX_F, &op_map);
EmplaceOpType("Reshape", {}, SupportedOpType::RESHAPE, &op_map);
EmplaceOpType("QuantizedReshape", {}, SupportedOpType::QUANTIZED_RESHAPE,
&op_map);
EmplaceOpType("Sigmoid", {}, SupportedOpType::SIGMOID_F, &op_map);
EmplaceOpType("Slice", {}, SupportedOpType::SLICE_F, &op_map);
EmplaceOpType("Add", {}, SupportedOpType::ADD_F, &op_map);
EmplaceOpType("Mul", {}, SupportedOpType::MUL_F, &op_map);
EmplaceOpType("Requantize", {}, SupportedOpType::REQUANTIZE_32_TO_8, &op_map);
EmplaceOpType("RequantizationRange", {},
SupportedOpType::REQUANTIZATION_RANGE_32, &op_map);
EmplaceOpType("Sub", {}, SupportedOpType::SUB_F, &op_map);
EmplaceOpType("Pack", {}, SupportedOpType::PACK_INT32, &op_map);
EmplaceOpType("StridedSlice", {}, SupportedOpType::STRIDED_SLICE_F, &op_map);
EmplaceOpType("ExpandDims", {}, SupportedOpType::EXPAND_DIMS_F, &op_map);
#ifdef ENABLE_EXPERIMENTAL_HEXNN_OPS
EmplaceOpType("QuantizedMul", {}, SupportedOpType::QUANTIZED_MUL_8x8to32,
&op_map);
EmplaceOpType("QuantizedAdd", {}, SupportedOpType::QUANTIZED_ADD_8p8to32,
&op_map);
EmplaceOpType("Pad", {}, SupportedOpType::PAD_F, &op_map);
EmplaceOpType("SpaceToBatchND", {}, SupportedOpType::SPACE_TO_BATCH_ND_F,
&op_map);
EmplaceOpType("BatchToSpaceND", {}, SupportedOpType::BATCH_TO_SPACE_ND_F,
&op_map);
EmplaceOpType("ResizeBilinear", {}, SupportedOpType::RESIZE_BILINEAR_F,
&op_map);
EmplaceOpType("ConcatV2", {}, SupportedOpType::CONCAT_V2_F, &op_map);
EmplaceOpType("Conv2DBackpropInput", {}, SupportedOpType::DECONV_F, &op_map);
EmplaceOpType("Tanh", {}, SupportedOpType::TANH_F, &op_map);
EmplaceOpType("Split", {}, SupportedOpType::SPLIT_F, &op_map);
EmplaceOpType("Transpose", {}, SupportedOpType::TRANSPOSE_F, &op_map);
EmplaceOpType("Concat", {}, SupportedOpType::CONCAT_F, &op_map);
#endif
return op_map;
}
HexagonOpsDefinitions::HexagonOpsDefinitions()
: op_name_to_soc_op_type_map_(BuildOpNameToSocOpTypeMap()) {}
/* static */ const IRemoteFusedGraphOpsDefinitions&
HexagonOpsDefinitions::getInstance() {
const static HexagonOpsDefinitions instance{};
return instance;
}
int HexagonOpsDefinitions::GetTotalOpsCount() const {
return static_cast<int>(SupportedOpType::SUPPORTED_OP_TYPE_COUNT);
}
int HexagonOpsDefinitions::GetOpIdFor(const string& op_type,
const DataTypeVector& dt_vec) const {
if (op_name_to_soc_op_type_map_.count(op_type) > 0) {
const std::vector<DataTypeToOp>& dt_to_op_vec =
op_name_to_soc_op_type_map_.at(op_type);
CHECK(!dt_to_op_vec.empty());
// If argument DataType is empty, return the first entry.
if (dt_vec.empty()) {
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
}
// If there is only one op_id registered with an empty DataTypeVector, we
// assume that the op supports any data type.
if (dt_to_op_vec.size() == 1 && std::get<0>(dt_to_op_vec.front()).empty()) {
return static_cast<int>(std::get<1>(dt_to_op_vec.front()));
}
for (const DataTypeToOp& data_type_to_op : dt_to_op_vec) {
if (std::get<0>(data_type_to_op) == dt_vec) {
return static_cast<int>(std::get<1>(data_type_to_op));
}
}
}
return IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID;
}
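// Example lookup (illustrative, not in the original file): "Softmax" is
// registered above with an empty DataTypeVector, so both
//   HexagonOpsDefinitions::getInstance().GetOpIdFor("Softmax", {})
//   HexagonOpsDefinitions::getInstance().GetOpIdFor("Softmax", {DT_FLOAT})
// resolve to static_cast<int>(SupportedOpType::SOFTMAX_F), while an
// unregistered op type yields IRemoteFusedGraphOpsDefinitions::INVALID_OP_ID.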
} // namespace tensorflow

View File

@ -1,58 +0,0 @@
/* Copyright 2016 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
#define TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_
#include <unordered_map>
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/i_remote_fused_graph_ops_definitions.h"
#include "tensorflow/core/platform/macros.h"
namespace tensorflow {
// HexagonOpsDefinitions provides the op definitions supported by the hexagon
// library.
// TODO(satok): Add functionality to call functions in the hexagon library.
class HexagonOpsDefinitions final : public IRemoteFusedGraphOpsDefinitions {
public:
static const IRemoteFusedGraphOpsDefinitions& getInstance();
int GetTotalOpsCount() const final;
int GetOpIdFor(const string& op_type, const DataTypeVector& dt) const final;
private:
enum class SupportedOpType;
using DataTypeToOp = std::tuple<DataTypeVector, SupportedOpType>;
HexagonOpsDefinitions();
static void EmplaceOpType(
const string& op_type, const DataTypeVector& dt_vec,
const SupportedOpType supported_op_type,
std::unordered_map<string, std::vector<DataTypeToOp>>* map);
static std::unordered_map<string, std::vector<DataTypeToOp>>
BuildOpNameToSocOpTypeMap();
const std::unordered_map<string, std::vector<DataTypeToOp>>
op_name_to_soc_op_type_map_;
TF_DISALLOW_COPY_AND_ASSIGN(HexagonOpsDefinitions);
};
} // namespace tensorflow
#endif // TENSORFLOW_CORE_KERNELS_HEXAGON_HEXAGON_OPS_DEFINITIONS_H_

View File

@ -1,34 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/hexagon/hexagon_control_wrapper.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
namespace tensorflow {
namespace hexagon_remote_fused_graph_executor_build {
Status BuildRemoteFusedGraphExecutor(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor) {
executor->reset(new HexagonControlWrapper());
return Status::OK();
}
static RemoteFusedGraphExecuteUtils::ExecutorBuildRegistrar
k_hexagon_remote_fused_graph_executor_build(
HexagonControlWrapper::REMOTE_FUSED_GRAPH_EXECUTOR_NAME,
BuildRemoteFusedGraphExecutor);
} // namespace hexagon_remote_fused_graph_executor_build
} // namespace tensorflow

View File

@ -1,42 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/core/kernels/i_remote_fused_graph_executor.h"
#include "tensorflow/core/kernels/remote_fused_graph_execute_utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
namespace hexagon_remote_fused_graph_executor_build {
Status BuildRemoteFusedGraphExecutor(
std::unique_ptr<IRemoteFusedGraphExecutor>* executor);
namespace {
TEST(HexagonBuildRemoteFusedGraphExecutorTest, BasicRun) {
std::unique_ptr<IRemoteFusedGraphExecutor> executor;
ASSERT_FALSE(static_cast<bool>(executor));
TF_ASSERT_OK(BuildRemoteFusedGraphExecutor(&executor));
ASSERT_TRUE(static_cast<bool>(executor));
ASSERT_NE(RemoteFusedGraphExecuteUtils::GetExecutorBuildFunc(
"build_hexagon_remote_fused_graph_executor"),
nullptr);
}
} // namespace
} // namespace hexagon_remote_fused_graph_executor_build
} // namespace tensorflow

View File

@ -1,94 +0,0 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// Wraps the hexagon rewriter in a transform so it can be used as part of the
// graph transform tool.
// A usage example, based on the Image Understanding pipeline:
/*
bazel build tensorflow/tools/graph_transforms:transform_graph
bazel-bin/tensorflow/tools/graph_transforms/transform_graph \
--in_graph=/tmp/tensorflow_inception_v3_stripped_optimized_quantized.pb \
--out_graph=\
/tmp/tensorflow_inception_v3_stripped_optimized_quantized_fused_hexagon.pb \
--inputs='Mul' \
--outputs='softmax' \
--transforms='\
rewrite_quantized_stripped_model_for_hexagon(
input_shape0="1,299,299,3" \
input_type0="float" \
)'
*/
#include "tensorflow/core/kernels/hexagon/graph_transfer_utils.h"
#include "tensorflow/core/kernels/hexagon/hexagon_ops_definitions.h"
#include "tensorflow/tools/graph_transforms/transform_utils.h"
namespace tensorflow {
namespace graph_transforms {
constexpr const char* const INPUT_SHAPE_PREFIX = "input_shape";
constexpr const char* const INPUT_TYPE_PREFIX = "input_type";
Status RewriteQuantizedStrippedModelForHexagon(
const GraphDef& input_graph_def, const TransformFuncContext& context,
GraphDef* output_graph_def) {
LOG(INFO) << "Transforming quantized stripped model to a remote fused "
"graph execute op...";
std::vector<std::pair<string, Tensor>> inputs;
std::vector<string> outputs;
for (auto i = 0; static_cast<size_t>(i) < context.input_names.size(); ++i) {
const string& input_name = context.input_names.at(i);
// Get input shape
string shape_string;
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
INPUT_SHAPE_PREFIX + std::to_string(i), "", &shape_string));
std::vector<string> split_shape = str_util::Split(shape_string, ',');
std::vector<int64> dims;
for (const string& dim : split_shape) {
int64 tmp;
CHECK(strings::safe_strto64(dim, &tmp));
dims.push_back(tmp);
}
// Get input data type
string data_type_string;
TF_RETURN_IF_ERROR(context.GetOneStringParameter(
INPUT_TYPE_PREFIX + std::to_string(i), "", &data_type_string));
DataType data_type;
CHECK(DataTypeFromString(data_type_string, &data_type))
<< "\"" << data_type_string << "\" was an invalid type";
LOG(INFO) << "Input(" << i << "): name = " << input_name
<< ", shape = " << shape_string
<< ", type = " << data_type_string;
inputs.emplace_back(input_name, Tensor(data_type, TensorShape(dims)));
}
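// Worked example (illustrative, not in the original file): with the transform
// invocation shown in the file header (--inputs='Mul',
// input_shape0="1,299,299,3", input_type0="float"), the loop above produces a
// single entry ("Mul", Tensor(DT_FLOAT, TensorShape({1, 299, 299, 3}))).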
for (const string& output_name : context.output_names) {
outputs.emplace_back(output_name);
}
GraphDef mutable_input_graph_def = input_graph_def;
*output_graph_def = GraphTransferUtils::BuildFusedGraphDef(
HexagonOpsDefinitions::getInstance(), "remote_fused_graph_execute_node",
inputs, outputs, &mutable_input_graph_def);
return Status::OK();
}
REGISTER_GRAPH_TRANSFORM("rewrite_quantized_stripped_model_for_hexagon",
RewriteQuantizedStrippedModelForHexagon);
} // namespace graph_transforms
} // namespace tensorflow

Some files were not shown because too many files have changed in this diff.