Merge remote-tracking branch 'upstream/master' into arc_mli_evaltensor_porting_conv

This commit is contained in:
gerbauz 2020-11-12 12:05:48 +03:00
commit ae7944f9c0
672 changed files with 21406 additions and 9066 deletions

View File

@ -45,6 +45,11 @@
* Removed deprecated `Interpreter::UseNNAPI(bool)` C++ API.
* Use `NnApiDelegate()` and related delegate configuration methods
directly.
* 16 bits quantization
* Added int16x8 support for ABS, REDUCE_MAX and REDUCE_MIN operators.
* Added support for saved model's session initializer through
`TFLiteConverter.from_saved_model`.
* TF Core:
* Corrected higher-order gradients of control flow constructs (`tf.cond`,
`tf.while_loop`, and compositions like `tf.foldl`) computed with

View File

@ -199,6 +199,7 @@ tf_cuda_library(
"//tensorflow/core:portable_tensorflow_lib_lite",
],
"//conditions:default": [
":logging",
":tf_status",
":tf_tensor",
"@com_google_absl//absl/strings",

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/tensor_interface.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/platform/tstring.h"
@ -138,8 +139,8 @@ class ImmediateExecutionContext : public AbstractContext {
}
//===--------------------------------------------------------------------===//
// Following are legacy features in TF Eager Runtime.
// TODO(tf-runtime): Figure out a way to deprecate following features after
// Following are features in current TF Eager Runtime.
// TODO(tfrt-devs): Figure out a way to deprecate following features after
// migrated to TFRT.
//===--------------------------------------------------------------------===//
// Clear pending nodes in thread executors and kernel caches.
@ -157,6 +158,22 @@ class ImmediateExecutionContext : public AbstractContext {
// Update the Eager Executor for current thread.
virtual void SetExecutorForThread(EagerExecutor* executor) = 0;
//===--------------------------------------------------------------------===//
// Following are helper functions to assist integrating TFRT with current
// TF eager runtime.
// TODO(b/172877902): These helper functions are currently used to support
// PyFuncOp on TFRT, and might be useful for ops that directly use low
// level TF APIs. Remove/replace the following functions when TFRT native
// ops are implemented.
//===--------------------------------------------------------------------===//
// Create an abstract tensor handle from tensorflow::Tensor.
virtual ImmediateExecutionTensorHandle* CreateLocalHandleFromTFTensor(
tensorflow::Tensor& t, const char* d_name) = 0;
// Convert a TFRT TensorHandle to tensorflow::TensorHandle.
virtual ImmediateExecutionTensorHandle* TFTensorHandleFromInterface(
ImmediateExecutionTensorHandle* handle) = 0;
protected:
explicit ImmediateExecutionContext(AbstractContextKind kind)
: AbstractContext(kind) {}

View File

@ -76,6 +76,7 @@ cc_library(
"//tensorflow/c:c_api",
"//tensorflow/c/eager:c_api",
"//tensorflow/c/eager:c_api_experimental",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"@com_google_absl//absl/types:optional",
"@com_google_absl//absl/types:span",

View File

@ -328,6 +328,17 @@ ParallelDevice::Execute(TFE_Context* context,
const char* operation_name,
const TFE_OpAttrs* attributes, int expected_max_outputs,
TF_Status* status) const {
std::vector<PartialTensorShape> expected_output_shapes(expected_max_outputs);
return Execute(context, inputs, operation_name, attributes,
expected_output_shapes, status);
}
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>>
ParallelDevice::Execute(
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
const std::vector<PartialTensorShape>& expected_output_shapes,
TF_Status* status) const {
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> result;
// Compute per-device per-output tensors
std::vector<std::vector<TensorHandlePtr>> per_device_output_tensors;
@ -344,7 +355,7 @@ ParallelDevice::Execute(TFE_Context* context,
}
device_thread->StartExecute(context, operation_name,
std::move(device_inputs), attributes,
expected_max_outputs);
expected_output_shapes.size());
}
StatusPtr first_bad_status(nullptr);
for (int device_index = 0; device_index < underlying_devices_.size();
@ -386,8 +397,15 @@ ParallelDevice::Execute(TFE_Context* context,
for (int j = 0; j < underlying_devices_.size(); ++j) {
components.push_back(std::move(per_device_output_tensors[j][i]));
}
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
if (expected_output_shapes[i].IsFullyDefined()) {
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components),
absl::Span<const int64>(expected_output_shapes[i].dim_sizes()),
status));
} else {
per_device_outputs.push_back(ParallelTensor::FromTensorHandles(
*this, std::move(components), status));
}
if (TF_GetCode(status) != TF_OK) return result;
}
result.emplace(std::move(per_device_outputs));
@ -396,9 +414,27 @@ ParallelDevice::Execute(TFE_Context* context,
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
TF_Status* status) {
TF_DataType dtype = TFE_TensorHandleDataType(components[0].get());
std::vector<int64_t> shape(
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
for (TensorHandlePtr& component : components) {
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
return std::unique_ptr<ParallelTensor>(
new ParallelTensor(parallel_device, std::move(components), shape, dtype));
}
std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status) {
std::vector<int64> shape(
TFE_TensorHandleNumDims(components[0].get(), status));
if (TF_GetCode(status) != TF_OK) return nullptr;
for (int i = 0; i < shape.size(); ++i) {
@ -406,11 +442,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
if (TF_GetCode(status) != TF_OK) return nullptr;
}
// Verify that the TensorHandle's shape and dtype match all of the component
// shapes and dtypes.
// Verify that the TensorHandle's shape matches all of the component shapes.
for (TensorHandlePtr& component : components) {
for (int i = 0; i < shape.size(); ++i) {
int64_t tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
int64 tensor_dim = TFE_TensorHandleDim(component.get(), i, status);
if (TF_GetCode(status) != TF_OK) return nullptr;
if (tensor_dim != shape[i]) {
// TODO(allenl): Allow shapes to differ.
@ -419,17 +454,10 @@ std::unique_ptr<ParallelTensor> ParallelTensor::FromTensorHandles(
"the same shape");
return nullptr;
}
if (TFE_TensorHandleDataType(component.get()) != dtype) {
TF_SetStatus(status, TF_INTERNAL,
"Components of a ParallelTensor must all have "
"the same dtype");
return nullptr;
}
}
}
return std::unique_ptr<ParallelTensor>(new ParallelTensor(
parallel_device, std::move(components), std::move(shape), dtype));
return FromTensorHandles(parallel_device, std::move(components),
absl::Span<const int64>(shape), status);
}
} // namespace parallel_device

View File

@ -26,6 +26,7 @@ limitations under the License.
#include "tensorflow/c/c_api.h"
#include "tensorflow/c/eager/c_api.h"
#include "tensorflow/c/eager/c_api_experimental.h"
#include "tensorflow/core/framework/tensor_shape.h"
namespace tensorflow {
namespace parallel_device {
@ -93,6 +94,15 @@ class ParallelDevice {
const char* operation_name, const TFE_OpAttrs* attributes,
int expected_max_outputs, TF_Status* status) const;
// Accepts inferred shapes for outputs, which if fully defined will avoid
// querying the shapes of the underlying TensorHandles. This allows async
// computation to continue without blocking.
absl::optional<std::vector<std::unique_ptr<ParallelTensor>>> Execute(
TFE_Context* context, const std::vector<ParallelTensor*>& inputs,
const char* operation_name, const TFE_OpAttrs* attributes,
const std::vector<PartialTensorShape>& expected_output_shapes,
TF_Status* status) const;
private:
// A sequence of device names, indicating which devices replicated operations
// are forwarded to.
@ -117,10 +127,15 @@ class ParallelDevice {
class ParallelTensor {
public:
// Construct a ParallelTensor from TensorHandles placed on the component
// devices of a ParallelDevice.
// devices of a ParallelDevice. Inspects `components` to determine a shape.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, TF_Status* status);
// Uses the provided shape without additional checks, which avoids blocking.
static std::unique_ptr<ParallelTensor> FromTensorHandles(
const ParallelDevice& parallel_device,
std::vector<TensorHandlePtr> components, absl::Span<const int64> shape,
TF_Status* status);
size_t num_tensors() const { return tensors_.size(); }
TFE_TensorHandle* tensor(size_t index) const { return tensors_[index].get(); }
@ -132,10 +147,10 @@ class ParallelTensor {
private:
ParallelTensor(const ParallelDevice& device,
std::vector<TensorHandlePtr> tensors,
std::vector<int64_t> shape, const TF_DataType dtype)
absl::Span<const int64> shape, const TF_DataType dtype)
: device_(device),
tensors_(std::move(tensors)),
shape_(std::move(shape)),
shape_(shape.begin(), shape.end()),
dtype_(dtype) {}
const ParallelDevice& device_;

View File

@ -80,5 +80,41 @@ TEST(PARALLEL_DEVICE_LIB, TestOpWithError) {
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
}
TEST(PARALLEL_DEVICE_LIB, TestExplicitOutputShape) {
std::unique_ptr<TF_Status, decltype(&TF_DeleteStatus)> status(
TF_NewStatus(), TF_DeleteStatus);
std::unique_ptr<TFE_ContextOptions, decltype(&TFE_DeleteContextOptions)> opts(
TFE_NewContextOptions(), TFE_DeleteContextOptions);
std::unique_ptr<TF_Buffer, decltype(&TF_DeleteBuffer)> config(
TF_CreateConfig(
/*xla*/ false,
/* gpu_memory_allow_growth */ true, /* num_cpu_devices */
2),
TF_DeleteBuffer);
TFE_ContextOptionsSetConfig(opts.get(), config->data, config->length,
status.get());
std::unique_ptr<TFE_Context, decltype(&TFE_DeleteContext)> context(
TFE_NewContext(opts.get(), status.get()), TFE_DeleteContext);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
std::vector<std::string> devices{
"/job:localhost/replica:0/task:0/device:CPU:0",
"/job:localhost/replica:0/task:0/device:CPU:1"};
ParallelDevice parallel_device(std::move(devices));
std::unique_ptr<TFE_Op, decltype(&TFE_DeleteOp)> handle_op(
TFE_NewOp(context.get(), "VarHandleOp", status.get()), TFE_DeleteOp);
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
TFE_OpSetAttrType(handle_op.get(), "dtype", TF_FLOAT);
TFE_OpSetAttrShape(handle_op.get(), "shape", /*dims=*/nullptr, /*num_dims=*/0,
status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
auto outputs = parallel_device.Execute(
context.get(), std::vector<ParallelTensor*>(), "VarHandleOp",
TFE_OpGetAttrs(handle_op.get()), {PartialTensorShape({})}, status.get());
ASSERT_TRUE(TF_GetCode(status.get()) == TF_OK) << TF_Message(status.get());
const std::vector<std::unique_ptr<ParallelTensor>>& handles = *outputs;
EXPECT_EQ(0, handles[0]->shape().size());
}
} // namespace parallel_device
} // namespace tensorflow

View File

@ -508,7 +508,7 @@ cc_library(
":flags",
":jit_compilation_passes",
"//tensorflow/compiler/jit/kernels:xla_ops_no_jit_rewrite_registration",
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"//tensorflow/compiler/tf2xla:mlir_bridge_pass",
"//tensorflow/compiler/tf2xla:xla_compiler",
"//tensorflow/compiler/tf2xla:xla_op_registry",
"//tensorflow/core:core_cpu_internal",

View File

@ -115,7 +115,7 @@ xla::StatusOr<std::string> GetCompilerIr(
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_arg_indices, inputs, variable_infos);
constant_arg_indices, inputs, variable_infos, dev);
TF_RETURN_IF_ERROR(args.status());
switch (stage) {

View File

@ -206,8 +206,9 @@ static Status CompileToLocalExecutable(
may_alias_resource_update;
xla::StatusOr<std::vector<XlaCompiler::Argument>> args =
XlaComputationLaunchContext::BuildXlaCompilerArguments(constants, inputs,
variable_infos);
XlaComputationLaunchContext::BuildXlaCompilerArguments(
constants, inputs, variable_infos,
static_cast<Device*>(ctx->device()));
TF_RETURN_IF_ERROR(args.status());
return cache->Compile(options, function, *args, compile_options,
lazy ? XlaCompilationCache::CompileMode::kLazy
@ -246,8 +247,6 @@ void XlaLocalLaunchBase::Compute(OpKernelContext* ctx) {
se::Stream* stream =
ctx->op_device_context() ? ctx->op_device_context()->stream() : nullptr;
VLOG(1) << "Executing XLA Computation...";
absl::optional<se::TfAllocatorAdapter> tf_allocator_adapter;
se::DeviceMemoryAllocator* allocator = GetAllocator(
&tf_allocator_adapter, ctx->device(),

View File

@ -140,6 +140,7 @@ XlaCompilationCache::BuildSignature(
for (const XlaCompiler::Argument& arg : args) {
switch (arg.kind) {
case XlaCompiler::Argument::kConstant:
case XlaCompiler::Argument::kConstantResource:
signature.arg_values.push_back(arg.constant_value);
break;
case XlaCompiler::Argument::kParameter:
@ -288,7 +289,7 @@ Status XlaCompilationCache::CompileSingleOp(
const ConfigProto* config = ctx->function_library()->config_proto();
// TODO(b/171039585): Support tf.VarIsInitializedOp using MLIR.
bool use_mlir = config &&
GetMlirBridgeRolloutPolicy(*config) ==
GetMlirBridgeRolloutPolicy(*graph, *config) ==
MlirBridgeRolloutPolicy::kEnabledByUser &&
node_def.op() != "VarIsInitializedOp";
#ifdef LIBTPU_ON_GCE

View File

@ -153,7 +153,8 @@ Status XlaCompileOnDemandOp::Compile(
ctx, variables_indices, variable_infos, variable_args));
args = XlaComputationLaunchContext::BuildXlaCompilerArguments(
constant_input_indices, inputs, variable_infos);
constant_input_indices, inputs, variable_infos,
static_cast<Device*>(ctx->device()));
TF_RETURN_IF_ERROR(args.status());
}

View File

@ -14,7 +14,6 @@ limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/jit/xla_kernel_creator.h"
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "absl/memory/memory.h"
#include "absl/strings/str_cat.h"
#include "absl/strings/str_format.h"
@ -23,6 +22,7 @@ limitations under the License.
#include "tensorflow/compiler/jit/flags.h"
#include "tensorflow/compiler/jit/kernels/xla_ops.h"
#include "tensorflow/compiler/tf2xla/const_analysis.h"
#include "tensorflow/compiler/tf2xla/mlir_bridge_pass.h"
#include "tensorflow/compiler/tf2xla/xla_op_registry.h"
#include "tensorflow/core/common_runtime/function.h"
#include "tensorflow/core/framework/node_def_builder.h"
@ -89,10 +89,21 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
// Make sure that kernels have been registered on the JIT device.
XlaOpRegistry::RegisterCompilationKernels();
// Get function body, constant args, and resource args.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
// Only check for compilability if the MLIR bridge is not enabled.
MlirBridgeRolloutPolicy policy = GetMlirBridgeRolloutPolicy(absl::nullopt);
if (policy == MlirBridgeRolloutPolicy::kDisabledByUser ||
policy == MlirBridgeRolloutPolicy::kDisabledAfterGraphAnalysis) {
absl::optional<ConfigProto> config_proto;
if (flr->config_proto()) {
config_proto = *flr->config_proto();
}
if (!IsMlirBridgePassEnabled(*fbody->graph, config_proto)) {
RecursiveCompilabilityChecker::UncompilableNodesMap uncompilable_nodes_map;
if (!IsCompilable(flr, node_def, &uncompilable_nodes_map)) {
std::vector<RecursiveCompilabilityChecker::UncompilableNodeInfo>
@ -121,15 +132,6 @@ static Status CreateXlaKernel(FunctionLibraryRuntime* flr,
}
}
// Get function body, constant args, and resource args.
NameAttrList function;
TF_RETURN_IF_ERROR(NameAndAttrsFromFunctionCall(node_def, &function));
const FunctionBody* fbody = nullptr;
std::vector<int> constant_arg_indices;
std::vector<int> resource_arg_indices;
TF_RETURN_IF_ERROR(GetBodyAndConstantsAndResources(
flr, function, &fbody, &constant_arg_indices, &resource_arg_indices));
MemoryTypeVector input_memory_types =
GetInputMemoryTypes(fbody, constant_arg_indices, resource_arg_indices);
MemoryTypeVector output_memory_types = GetOutputMemoryTypes(fbody);

View File

@ -564,11 +564,26 @@ xla::StatusOr<std::vector<XlaCompiler::Argument>>
XlaComputationLaunchContext::BuildXlaCompilerArguments(
absl::Span<int const> must_be_constant_idxs,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_args) {
absl::Span<VariableInfo const> variable_args, Device* device) {
CHECK(absl::c_is_sorted(must_be_constant_idxs));
std::vector<XlaCompiler::Argument> out;
out.resize(inputs.size());
// TODO(cheshire): Avoid duplication with framework/op_kernel.h
DeviceContext* device_context = nullptr;
TF_RETURN_IF_ERROR(device->TryGetDeviceContext(&device_context));
bool using_default_context = false;
auto cleanup = xla::MakeCleanup([&] {
if (device_context != nullptr && !using_default_context) {
device_context->Unref();
}
});
if (device_context == nullptr) {
using_default_context = true;
auto* dev_info = device->tensorflow_gpu_device_info();
if (dev_info) device_context = dev_info->default_context;
}
absl::flat_hash_map<int, const VariableInfo*> variable_info_lookup;
for (const VariableInfo& info : variable_args) {
CHECK(!info.var() || info.lock_held())
@ -581,18 +596,7 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
const Tensor* input = inputs[input_num];
XlaCompiler::Argument& arg = out[input_num];
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
// Handles compile-time constants.
// TODO(b/157241314): Support constants located in resource variables.
TF_RET_CHECK(input->dtype() != DT_RESOURCE)
<< "tf2xla bridge does not support must-be-constants located in "
"resource variables; try moving them to a tensor";
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input->dtype();
arg.shape = input->shape();
arg.constant_value = *input;
} else if (variable_info_lookup.count(input_num)) {
if (variable_info_lookup.count(input_num)) {
// Handles resource variables.
TF_RET_CHECK(input->dtype() == DT_RESOURCE);
const VariableInfo& variable = *variable_info_lookup[input_num];
@ -613,6 +617,25 @@ XlaComputationLaunchContext::BuildXlaCompilerArguments(
arg.type = DT_INVALID;
arg.shape = TensorShape();
}
if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
TF_RET_CHECK(variable.var() && variable.var()->is_initialized);
const Tensor* value = variable.var()->tensor();
Tensor value_on_host(value->dtype(), value->shape());
if (!device_context) {
value_on_host = *value;
} else {
TF_RETURN_IF_ERROR(device_context->CopyDeviceTensorToCPUSync(
value, "", device, &value_on_host));
}
arg.kind = XlaCompiler::Argument::kConstantResource;
arg.constant_value = value_on_host;
}
} else if (absl::c_binary_search(must_be_constant_idxs, input_num)) {
arg.kind = XlaCompiler::Argument::kConstant;
arg.type = input->dtype();
arg.shape = input->shape();
arg.constant_value = *input;
} else {
// Normal inputs.
TF_RET_CHECK(input->dtype() != DT_RESOURCE);

View File

@ -143,7 +143,8 @@ class XlaComputationLaunchContext {
static xla::StatusOr<std::vector<XlaCompiler::Argument>>
BuildXlaCompilerArguments(absl::Span<int const> must_be_constant_idxs,
absl::Span<const Tensor* const> inputs,
absl::Span<VariableInfo const> variable_args);
absl::Span<VariableInfo const> variable_args,
Device* device);
// Add all inputs within `ctx` as XLA arguments (returned by arguments()).
// `variables` is a map from TensorFlow argument number to resource variable.

View File

@ -3,7 +3,11 @@
load("//tensorflow:tensorflow.bzl", "filegroup")
load("//tensorflow/core/platform:rules_cc.bzl", "cc_library")
load("//tensorflow:tensorflow.bzl", "tf_cc_binary")
load(
"//tensorflow:tensorflow.bzl",
"tf_cc_binary",
"tf_cc_test",
)
package(
default_visibility = [
@ -126,12 +130,14 @@ cc_library(
srcs = ["mlir_graph_optimization_pass.cc"],
hdrs = ["mlir_graph_optimization_pass.h"],
deps = [
"//tensorflow/compiler/mlir:mlir_bridge_rollout_policy",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:convert_graphdef",
"//tensorflow/compiler/mlir/tensorflow:device_util",
"//tensorflow/compiler/mlir/tensorflow:dump_mlir_util",
"//tensorflow/compiler/mlir/tensorflow:mlir_roundtrip_flags",
"//tensorflow/core:core_cpu",
"//tensorflow/core:lib",
"@com_google_absl//absl/container:flat_hash_set",
"@llvm-project//llvm:Support",
"@llvm-project//mlir:IR",
@ -198,11 +204,22 @@ cc_library(
visibility = ["//visibility:public"],
deps = [
"//tensorflow/compiler/jit:flags",
"//tensorflow/core:graph",
"//tensorflow/core:protos_all_cc",
"@com_google_absl//absl/types:optional",
],
)
tf_cc_test(
name = "mlir_graph_optimization_pass_test",
srcs = ["mlir_graph_optimization_pass_test.cc"],
deps = [
":mlir_graph_optimization_pass",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
],
)
filegroup(
name = "litfiles",
srcs = glob(["runlit*py"]),

View File

@ -87,6 +87,32 @@ Value InsertAlloc(Location loc, OpResult result,
return alloc;
}
/// Converts the results of the operation `op` to memref types and append them
/// to the `results` vector.
LogicalResult ConvertResults(Operation* op, SmallVectorImpl<Value>& results,
ConversionPatternRewriter& rewriter) {
for (auto result : llvm::enumerate(op->getResults())) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) return failure();
if (resultType.hasStaticShape()) {
results.push_back(InsertAlloc(op->getLoc(), result.value(), &rewriter));
continue;
}
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
SmallVector<Value, 1> results_shape;
auto status = shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
if (failed(status)) return failure();
results.push_back(
InsertDynamicAllocAndDealloc(op->getLoc(), result.value(),
results_shape[result.index()], &rewriter));
}
return success();
}
template <typename HloOpTy>
class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
public:
@ -95,29 +121,8 @@ class HloToLhloOpConverter : public BaseOpConversion<HloOpTy> {
HloOpTy hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 4> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
SmallVector<Value, 1> results_shape;
auto status =
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape);
if (failed(status)) return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
rewriter.create<mhlo::HloToLhloOp<HloOpTy>>(op->getLoc(), llvm::None,
buffer_args, op->getAttrs());
rewriter.replaceOp(
@ -139,28 +144,8 @@ class HloToLhloOpConverter<mhlo::DotOp> : public BaseOpConversion<mhlo::DotOp> {
mhlo::DotOp hloOp, ArrayRef<Value> operands,
ConversionPatternRewriter& rewriter) const final {
Operation* op = hloOp.getOperation();
const auto& original_results = op->getResults();
SmallVector<Value, 2> buffer_args(operands.begin(), operands.end());
for (auto result : llvm::enumerate(original_results)) {
RankedTensorType resultType =
result.value().getType().dyn_cast<RankedTensorType>();
if (!resultType) {
return failure();
}
if (resultType.hasStaticShape()) {
buffer_args.push_back(
InsertAlloc(op->getLoc(), result.value(), &rewriter));
} else {
SmallVector<Value, 1> results_shape;
auto shape_type_op = dyn_cast<InferShapedTypeOpInterface>(op);
if (!shape_type_op) return failure();
if (failed(
shape_type_op.reifyReturnTypeShapes(rewriter, results_shape)))
return failure();
buffer_args.push_back(InsertDynamicAllocAndDealloc(
op->getLoc(), result.value(), results_shape.front(), &rewriter));
}
}
if (failed(ConvertResults(op, buffer_args, rewriter))) return failure();
// TODO(silvasean): Move this helper to MLIR core.
auto make_elements_attr = [&rewriter](ArrayRef<int64_t> integers) {
@ -194,8 +179,7 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
Value transformed_operand =
InsertDynamicMemrefCastOp(op, operands.front(), &rewriter);
rewriter.create<lmhlo::BroadcastInDimOp>(
loc, transformed_operand, resultBuffer, op.broadcast_dimensions());
rewriter.create<lmhlo::CopyOp>(loc, transformed_operand, resultBuffer);
rewriter.replaceOp(op, {resultBuffer});
@ -211,48 +195,76 @@ struct HloToLhloDynamicBroadcastInDimOpConverter
auto loc = op.getLoc();
auto operand_type = operand.getType().cast<MemRefType>();
auto operand_shape = operand_type.getShape();
auto operand_rank = operand_type.getRank();
SmallVector<Value, 2> sizes, strides;
sizes.reserve(operand_shape.size());
strides.reserve(operand_shape.size());
auto result_type = op.getType().cast<RankedTensorType>();
auto result_rank = result_type.getRank();
Value zero = b->create<ConstantIndexOp>(loc, 0);
Value one = b->create<ConstantIndexOp>(loc, 1);
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
Value broadcast_dim_value =
b->create<ConstantIndexOp>(loc, dim.value().getSExtValue());
Value result_dim_size = b->create<ExtractElementOp>(
loc, op.output_dimensions(), broadcast_dim_value);
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[dim.index()])
? b->create<DimOp>(loc, operand, dim.index()).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[dim.index()])
.getResult();
// TODO(pifon): Revisit if this cast is needed. Maybe we can use
// tensor<index> for `output_dimensions` as well.
// Compute a reversed scan product. Compute the stride for the dimensions so
// far, working from minor to major dimensions. Additionally, save the
// operand shape Values to use in the next loop.
SmallVector<Value, 2> operand_strides(operand_rank, one);
SmallVector<Value, 2> operand_sizes(operand_rank, one);
Value stride_so_far = one;
for (int i = operand_rank - 1; i >= 0; --i) {
Value operand_dim_size =
ShapedType::isDynamic(operand_shape[i])
? b->create<DimOp>(loc, operand, i).getResult()
: b->create<ConstantIndexOp>(loc, operand_shape[i]).getResult();
operand_sizes[i] = operand_dim_size;
operand_strides[i] = stride_so_far;
if (i > 0) {
stride_so_far = b->create<MulIOp>(loc, stride_so_far, operand_dim_size);
}
}
SmallVector<Value, 2> sizes, strides;
sizes.reserve(result_rank);
strides.reserve(result_rank);
DenseMap<int, int> output_to_input_dim;
for (auto dim : llvm::enumerate(op.broadcast_dimensions())) {
output_to_input_dim[dim.value().getSExtValue()] = dim.index();
}
for (int i = 0; i < result_rank; ++i) {
Value i_val = b->create<ConstantIndexOp>(loc, i);
Value result_dim_size =
b->create<ExtractElementOp>(loc, op.output_dimensions(), i_val);
if (!result_dim_size.getType().isIndex()) {
result_dim_size =
b->create<IndexCastOp>(loc, result_dim_size, b->getIndexType());
}
sizes.push_back(result_dim_size);
auto it = output_to_input_dim.find(i);
// If the rank of the output is greater than the rank of the input, i.e.
// there was no output dimension in the inverse broadcast_dimensions map
// we also set stride to 0 to emulate padding of the shape with 1s and the
// corresponding expansion.
if (it == output_to_input_dim.end()) {
strides.push_back(zero);
continue;
}
// There can be two cases:
// 1) Operand dim == result dim => expansion is not needed => stride := 1.
// 1) Operand dim == result dim => expansion is not needed
// => stride flattened buffer stride
// 2) Operand dim < result dim => expansion is needed => stride := 0.
Value is_expansion = b->create<CmpIOp>(loc, CmpIPredicate::slt,
operand_dim_size, result_dim_size);
strides.push_back(
b->create<mlir::SelectOp>(loc, is_expansion, zero, one));
// Size of input dim can be set to the size of the corresponding output
// dimension for both cases.
sizes.push_back(result_dim_size);
int dim = it->second;
Value is_expansion = b->create<CmpIOp>(
loc, CmpIPredicate::slt, operand_sizes[dim], result_dim_size);
strides.push_back(b->create<mlir::SelectOp>(loc, is_expansion, zero,
operand_strides[dim]));
}
// Type-erased memref type with static rank, dynamic sizes and strides.
SmallVector<int64_t, 2> dynamic_layout(operand_shape.size(),
SmallVector<int64_t, 2> dynamic_layout(result_rank,
MemRefType::kDynamicStrideOrOffset);
SmallVector<int64_t, 2> dynamic_shape(operand_shape.size(),
SmallVector<int64_t, 2> dynamic_shape(result_rank,
MemRefType::kDynamicSize);
auto type_erased_memref_type = MemRefType::get(
dynamic_shape, operand_type.getElementType(),

View File

@ -1,4 +1,6 @@
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting -buffer-deallocation -split-input-file %s -o - | FILECHECK_OPTS="" FileCheck %s
// RUN: mlir-hlo-opt -hlo-legalize-to-lhlo -buffer-hoisting \
// RUN: -buffer-deallocation -split-input-file -cse %s -o - \
// RUN: | FILECHECK_OPTS="" FileCheck %s
// CHECK-LABEL: func @attrs
func @attrs_copy(%operand: memref<2x2xf32>, %result: memref<2x2xf32>) {
@ -153,64 +155,41 @@ func @broadcast(%operand: memref<5xf32>, %result: memref<10x5xf32>) {
// -----
func @external_func() -> tensor<3xi64>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1)[s0, s1] -> (d0 * s0 + d1 * s1)>
// CHECK: #[[MAP:.*]] = affine_map<(d0, d1, d2)[s0, s1, s2] -> (d0 * s0 + d1 * s1 + d2 * s2)>
// CHECK-LABEL: func @dyn_broadcast
func @dyn_broadcast(%operand: memref<?x?xf32>) {
// CHECK-SAME: (%[[OPERAND:.*]]: memref<?x?xf32>)
func @dyn_broadcast(%operand: memref<?x?xf32>) -> index {
// CHECK-SAME: %[[OPERAND:.*]]: memref<?x?xf32>
%tensor_operand = tensor_load %operand : memref<?x?xf32>
%c1 = constant 1 : i64
%shape = tensor_from_elements %c1, %c1, %c1 : tensor<3xi64>
%tensor_result = "mhlo.dynamic_broadcast_in_dim"(%tensor_operand, %shape) {
broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
} : (tensor<?x?xf32>, tensor<3xi64>) -> tensor<?x?x?xf32>
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<3xi64>
// CHECK: %[[IC0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<3xi64>
// CHECK: %[[IC1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]][%[[C2]]] : tensor<3xi64>
// CHECK: %[[IC2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[IC0]], %[[IC1]], %[[IC2]])
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[C1__:.*]] = constant 1 : index
// CHECK: %[[EL1_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1__]]] : tensor<3xi64>
// CHECK: %[[C0___:.*]] = constant 0 : index
// CHECK: %[[OPERAND_DIM_0:.*]] = dim %[[OPERAND]], %[[C0___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_1:.*]] = index_cast %[[EL1_]] : i64 to index
// CHECK: %[[EXPAND_0:.*]] = cmpi "slt", %[[OPERAND_DIM_0]], %[[RESULT_DIM_1]]
// CHECK: %[[STRIDE_0:.*]] = select %[[EXPAND_0]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[C2_:.*]] = constant 2 : index
// CHECK: %[[EL2_:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2_]]] : tensor<3xi64>
// CHECK: %[[C1___:.*]] = constant 1 : index
// CHECK: %[[OPERAND_DIM_1:.*]] = dim %[[OPERAND]], %[[C1___]] : memref<?x?xf32>
// CHECK: %[[RESULT_DIM_2:.*]] = index_cast %[[EL2_]] : i64 to index
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPERAND_DIM_1]], %[[RESULT_DIM_2]]
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0_]], %[[C1_]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to
// CHECK-SAME: offset: [0],
// CHECK-SAME: sizes: {{\[}}%[[RESULT_DIM_1]], %[[RESULT_DIM_2]]]
// CHECK-SAME: strides: {{\[}}%[[STRIDE_0]], %[[STRIDE_1]]]
// CHECK-SAME: : memref<?x?xf32> to memref<?x?xf32, #map>
// CHECK: "lmhlo.broadcast_in_dim"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) {
// CHECK-SAME: broadcast_dimensions = dense<[1, 2]> : tensor<2xi64>
// CHECK-SAME: } : (memref<?x?xf32, #[[MAP]]>, memref<?x?x?xf32>) -> ()
// Do not store the value back to avoid the tensor-store being rewritten to
// a copy into the pre-allocated argument.
return
%rank = rank %tensor_result : tensor<?x?x?xf32>
return %rank : index
}
// CHECK: %[[SHAPE:.*]] = tensor_from_elements
// CHECK: %[[C0:.*]] = constant 0 : index
// CHECK: %[[EL0:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C0]]] : tensor<3xi64>
// CHECK: %[[SIZE_0:.*]] = index_cast %[[EL0]] : i64 to index
// CHECK: %[[C1:.*]] = constant 1 : index
// CHECK: %[[EL1:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C1]]] : tensor<3xi64>
// CHECK: %[[SIZE_1:.*]] = index_cast %[[EL1]] : i64 to index
// CHECK: %[[C2:.*]] = constant 2 : index
// CHECK: %[[EL2:.*]] = extract_element %[[SHAPE]]{{\[}}%[[C2]]] : tensor<3xi64>
// CHECK: %[[SIZE_2:.*]] = index_cast %[[EL2]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]) : memref<?x?x?xf32>
// CHECK: %[[OPER_DIM_1:.*]] = dim %[[OPERAND]], %[[C1]] : memref<?x?xf32>
// CHECK: %[[OP_STRIDE_0:.*]] = muli %[[C1]], %[[OPER_DIM_1]] : index
// CHECK: %[[OPER_DIM_0:.*]] = dim %[[OPERAND]], %[[C0]] : memref<?x?xf32>
// CHECK: %[[EXPAND_1:.*]] = cmpi "slt", %[[OPER_DIM_0]], %[[SIZE_1]] : index
// CHECK: %[[STRIDE_1:.*]] = select %[[EXPAND_1]], %[[C0]], %[[OP_STRIDE_0]] : index
// CHECK: %[[EXPAND_2:.*]] = cmpi "slt", %[[OPER_DIM_1]], %[[SIZE_2]] : index
// CHECK: %[[STRIDE_2:.*]] = select %[[EXPAND_2]], %[[C0]], %[[C1]] : index
// CHECK: %[[TRANSFORMED_MEMREF:.*]] = memref_reinterpret_cast %[[OPERAND]] to offset: [0], sizes: {{\[}}%[[SIZE_0]], %[[SIZE_1]], %[[SIZE_2]]], strides: {{\[}}%[[C0]], %[[STRIDE_1]], %[[STRIDE_2]]]: memref<?x?xf32> to memref<?x?x?xf32, #map>
// CHECK: "lmhlo.copy"(%[[TRANSFORMED_MEMREF]], %[[RESULT]]) : (memref<?x?x?xf32, #map>, memref<?x?x?xf32>) -> ()
// CHECK: dealloc %[[RESULT]] : memref<?x?x?xf32>
// -----
@ -483,11 +462,9 @@ func @add_dyn(%lhs: tensor<?x?xf32>, %rhs: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.add"(%arg0, %arg1, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -508,11 +485,9 @@ func @tanh_dyn(%arg0: tensor<?x?xf32>) {
// CHECK: %[[DIM1:.*]] = dim %arg0, %[[C1]] : memref<?x?xf32>
// CHECK: %[[IC1:.*]] = index_cast %[[DIM1]] : index to i64
// CHECK: %[[SHAPE:.*]] = tensor_from_elements %[[IC0]], %[[IC1]] : tensor<2xi64>
// CHECK: %[[C0_:.*]] = constant 0 : index
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0_]]] : tensor<2xi64>
// CHECK: %[[EE0:.*]] = extract_element %[[SHAPE]][%[[C0]]] : tensor<2xi64>
// CHECK: %[[ICS0:.*]] = index_cast %[[EE0]] : i64 to index
// CHECK: %[[C1_:.*]] = constant 1 : index
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1_]]] : tensor<2xi64>
// CHECK: %[[EE1:.*]] = extract_element %[[SHAPE]][%[[C1]]] : tensor<2xi64>
// CHECK: %[[ICS1:.*]] = index_cast %[[EE1]] : i64 to index
// CHECK: %[[RESULT:.*]] = alloc(%[[ICS0]], %[[ICS1]])
// CHECK: "lmhlo.tanh"(%arg0, %[[RESULT]]) : (memref<?x?xf32>, memref<?x?xf32>) -> ()
@ -645,7 +620,7 @@ func @shape_assuming_memref(%arg0: tensor<?xf16>) -> tensor<?xf16> {
%4 = tensor_cast %3 : tensor<?xindex> to tensor<1xindex>
%5 = "mhlo.dynamic_broadcast_in_dim"(%0, %4) {broadcast_dimensions = dense<> : tensor<0xi64>} : (tensor<f16>, tensor<1xindex>) -> tensor<?xf16>
%6 = "mhlo.dynamic_broadcast_in_dim"(%arg0, %4) {broadcast_dimensions = dense<0> : tensor<1xi64>} : (tensor<?xf16>, tensor<1xindex>) -> tensor<?xf16>
// CHECK: "lmhlo.maximum"(%6, %9, %20) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
// CHECK: "lmhlo.maximum"(%{{.*}}, %{{.*}}, %{{.*}}) : (memref<?xf16>, memref<?xf16>, memref<?xf16>) -> ()
%7 = mhlo.maximum %5, %6 : tensor<?xf16>
// CHECK: shape.assuming_yield %{{.*}} : memref<?xf16>
shape.assuming_yield %7 : tensor<?xf16>

View File

@ -390,6 +390,7 @@ cc_library(
"transforms/generated_legalize_tf.inc",
"transforms/generated_lower_static_tensor_list.inc",
"transforms/generated_prepare_tf.inc",
"transforms/insert_call_once_op.cc",
"transforms/legalize_tf.cc",
"transforms/legalize_tf_while.cc",
"transforms/lower_static_tensor_list.cc",

View File

@ -453,6 +453,11 @@ class Translator {
mlir::TFL::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Build call once operator.
BufferOffset<tflite::Operator> BuildCallOnceOperator(
mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results);
// Builds custom operators.
// Templated on a) data type of custom_option to be stored into flatbuffer,
// and b) TFL custom op type.
@ -787,6 +792,22 @@ BufferOffset<tflite::Operator> Translator::BuildIfOperator(
builtin_options);
}
BufferOffset<tflite::Operator> Translator::BuildCallOnceOperator(
mlir::TFL::CallOnceOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
auto opcode_index =
GetOpcodeIndex("call_once", tflite::BuiltinOperator_CALL_ONCE);
int init_subgraph_index =
subgraph_index_map_.at(op.session_init_function().str());
auto builtin_options =
tflite::CreateCallOnceOptions(builder_, init_subgraph_index).Union();
auto inputs = builder_.CreateVector(operands);
auto outputs = builder_.CreateVector(results);
return tflite::CreateOperator(builder_, opcode_index, inputs, outputs,
tflite::BuiltinOptions_CallOnceOptions,
builtin_options);
}
BufferOffset<tflite::Operator> Translator::BuildWhileOperator(
mlir::TF::WhileOp op, const std::vector<int32_t>& operands,
const std::vector<int32_t>& results) {
@ -1026,6 +1047,12 @@ Optional<BufferOffset<tflite::Operator>> Translator::BuildOperator(
return llvm::None;
}
if (*builtin_code == tflite::BuiltinOperator_CALL_ONCE) {
if (auto initOp = dyn_cast<mlir::TFL::CallOnceOp>(inst)) {
return BuildCallOnceOperator(initOp, operands, results);
}
}
std::string op_name = inst->getName().getStringRef().str();
uint32_t opcode_index = GetOpcodeIndex(op_name, *builtin_code);

View File

@ -448,13 +448,54 @@ StatusOr<Operation*> BuildExternalConstOp(const tflite::TensorT& tensor,
return op.getOperation();
}
// Gets a constant splat for the given value of type. Requires value to be of
// type static shaped RankedTensorType. `unique_index` is used to get the unique
// value for the attribute.
static mlir::ElementsAttr GetSplat(RankedTensorType type, int unique_index,
OpBuilder builder) {
mlir::Type element_ty = getElementTypeOrSelf(type);
if (element_ty.isSignlessInteger())
return DenseElementsAttr::get(
type, builder.getIntegerAttr(element_ty, unique_index));
if (element_ty.isa<mlir::FloatType>())
return DenseElementsAttr::get(
type, builder.getFloatAttr(element_ty, unique_index));
if (auto qtype = element_ty.dyn_cast<QuantizedType>()) {
mlir::RankedTensorType new_type =
RankedTensorType::get(type.getShape(), qtype.getStorageType());
return DenseElementsAttr::get(
new_type, builder.getIntegerAttr(qtype.getStorageType(), unique_index));
}
llvm_unreachable("unhandled element type");
}
// TODO(b/172664358): Creates a new op instead of reusing constant op.
// Creates a constant op to represent stateful variable. The function static
// variable `stateful_variable_idx` is used as a unique value for each constant
// to avoid CSEed. `tensor` is the data structure of flatbuffer. `shaped_type`
// is the ShapedType for the const op.
Operation* BuildVariableOp(const tflite::TensorT& tensor,
mlir::RankedTensorType shaped_type,
OpBuilder builder, Location loc) {
static int stateful_variable_idx = 0;
mlir::ElementsAttr value =
GetSplat(shaped_type, stateful_variable_idx++, builder);
if (IsQuantized(tensor)) {
auto op = builder.create<tfl::QConstOp>(
loc, mlir::TypeAttr::get(shaped_type), value);
return op.getOperation();
}
auto op = builder.create<tfl::ConstOp>(loc, value);
return op.getOperation();
}
StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
const std::vector<uint8_t>& buffer,
OpBuilder builder, Location loc) {
if (buffer.empty()) {
return errors::InvalidArgument("Constant's buffer may not be empty");
}
bool is_variable, OpBuilder builder,
Location loc) {
TF_ASSIGN_OR_RETURN(auto type, GetTensorType(tensor, builder,
/*shapeless_are_scalars=*/true,
/*is_constant=*/true));
@ -466,7 +507,9 @@ StatusOr<Operation*> BuildConstOp(const tflite::TensorT& tensor,
auto elem_type = shaped_type.getElementType();
mlir::ElementsAttr value;
if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
if (is_variable) {
return BuildVariableOp(tensor, shaped_type, builder, loc);
} else if (auto float_type = elem_type.dyn_cast<mlir::FloatType>()) {
TF_ASSIGN_OR_RETURN(value,
ConvertFloatBuffer(shaped_type, float_type, buffer));
} else if (elem_type.isa<mlir::IntegerType, QuantizedType>()) {
@ -846,19 +889,8 @@ StatusOr<FuncOp> ConvertSubgraph(
GetTensorIndices(subgraph, ordered_input_arrays));
}
// Add state variables to inputs.
absl::flat_hash_set<int32_t> input_index_set(func_inputs.begin(),
func_inputs.end());
for (int i = 0, end = subgraph.tensors.size(); i < end; i++) {
auto& tensor = *subgraph.tensors.at(i);
if (tensor.is_variable && !input_index_set.contains(i)) {
func_inputs.emplace_back(i);
input_index_set.insert(i);
}
}
for (auto input_or_variable : func_inputs) {
auto& tensor = *subgraph.tensors.at(input_or_variable);
for (int input : func_inputs) {
auto& tensor = *subgraph.tensors.at(input);
// TODO(b/138222071) Graph inputs must have static shape per the exporter,
// but we cannot differentiate scalars from unranked tensors.
// Here we reverse the default assumption that shape = [] means unranked.
@ -889,7 +921,8 @@ StatusOr<FuncOp> ConvertSubgraph(
}
for (auto output : func_outputs) {
const bool is_func_input = input_index_set.contains(output);
const bool is_func_input = std::find(func_inputs.begin(), func_inputs.end(),
output) != func_inputs.end();
bool is_constant = !is_op_output[output] && !is_func_input;
// There are 2 cases tensor is scalar when it doesn't have a shape in
// flatbuffer:
@ -991,7 +1024,7 @@ StatusOr<FuncOp> ConvertSubgraph(
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
op_builder, const_loc)
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
op_builder, const_loc);
const_tensor.is_variable, op_builder, const_loc);
if (!op_or_err.ok()) {
return emitError(const_loc, op_or_err.status().ToString()),
op_or_err.status();
@ -1051,7 +1084,7 @@ StatusOr<FuncOp> ConvertSubgraph(
? BuildExternalConstOp(const_tensor, const_tensor.buffer,
op_builder, const_loc)
: BuildConstOp(const_tensor, buffers[const_tensor.buffer]->data,
op_builder, const_loc);
const_tensor.is_variable, op_builder, const_loc);
if (!op_or_err.ok()) {
return emitError(const_loc, op_or_err.status().ToString()),
op_or_err.status();

View File

@ -1972,6 +1972,43 @@ OpFoldResult ConstOp::fold(ArrayRef<Attribute> operands) {
return value();
}
//===----------------------------------------------------------------------===//
// CastOp
//===----------------------------------------------------------------------===//
OpFoldResult CastOp::fold(ArrayRef<Attribute> operands) {
assert(operands.size() == 1);
// For now, only supports cast between integer types.
auto elements_attr = operands[0].dyn_cast_or_null<DenseIntElementsAttr>();
if (!elements_attr) {
return nullptr;
}
auto result_element_type =
getType().cast<ShapedType>().getElementType().dyn_cast<IntegerType>();
auto operand_element_type = input()
.getType()
.cast<ShapedType>()
.getElementType()
.dyn_cast<IntegerType>();
// Returns nullptr if either result/operand element type is not integer.
if (!result_element_type || !operand_element_type) {
return nullptr;
}
const bool is_input_unsigned = operand_element_type.isUnsigned();
const int output_bitwidth = result_element_type.getWidth();
// The integer cast op is the same as C integer cast. Depends on the operand
// type's signedness, we will determine whether or not sign extension is
// needed.
auto cast = [&](APInt value) {
return is_input_unsigned ? value.zextOrTrunc(output_bitwidth)
: value.sextOrTrunc(output_bitwidth);
};
return elements_attr.mapValues(result_element_type, cast);
}
//===----------------------------------------------------------------------===//
// SelectV2Op
//===----------------------------------------------------------------------===//

View File

@ -3405,7 +3405,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
}];
let arguments = (ins
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$input,
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$input,
TFL_I32Tensor:$begin,
TFL_I32Tensor:$end,
TFL_I32Tensor:$strides,
@ -3418,7 +3418,7 @@ def TFL_StridedSliceOp: TFL_Op<"strided_slice", [
);
let results = (outs
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8]>:$output
TFL_TensorOf<[F32, I32, I64, I8, UI8, QI8, QUI8, I1, I16, QI16, TFL_Quint8, TFL_Str]>:$output
);
let hasOptions = 1;
@ -3443,6 +3443,8 @@ def TFL_CastOp : TFL_Op<"cast", [
// TFLite's cast op does not utilize CastOptions, instead derives types
// from the TfLiteTensors.
let hasOptions = 0;
let hasFolder = 1;
}
def TFL_MirrorPadOp: TFL_Op<"mirror_pad", [
@ -4358,6 +4360,21 @@ def TFL_WhileOp : Op<TFL_Dialect, "while", [
let hasCanonicalizer = 1;
}
def TFL_CallOnceOp : TFL_Op<"call_once", []> {
let summary = "Invokes an initialization function";
let description = [{
This operation invokes the given initialization function for the session
initializer in tf saved model dialect.
}];
let arguments = (ins
StrAttr:$session_init_function
);
let results = (outs);
}
def TFL_CustomOp : Op<TFL_Dialect, "custom", [
NoSideEffect, NoQuantizableResult]> {
let summary = "Custom op";

View File

@ -52,6 +52,12 @@ struct QuantizationSpecs {
// weight FakeQuant).
bool disable_per_channel = false;
// When set to true, the fixed output ranges of the activation ops (tanh,
// sigmoid, etc.) are not enforced. Then, to quantize these ops, quantization
// emulation ops should be specified after the ops in the input graph. This
// flag should be set to false for post-training quantization.
bool disable_enforced_fixed_output_range = false;
// The node type when the model is exported. Currently this is limited to
// DT_FLOAT, DT_HALF, DT_QINT8, and DT_QUINT8. When DT_HALF is used, the
// `weight_quantization` flag needs to set to true. When DT_QUINT8 is used,

View File

@ -587,3 +587,55 @@ func @rsqrt_bf16() -> tensor<bf16> {
// CHECK: %[[CST:.*]] = constant dense<5.000000e-01> : tensor<bf16>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_i64_to_i32
func @cast_i64_to_i32() -> tensor<5xi32> {
%cst = constant dense<[-1, 0, 1, 2147483647, 2147483648]> : tensor<5xi64>
%0 = "tfl.cast"(%cst) : (tensor<5xi64>) -> tensor<5xi32>
return %0 : tensor<5xi32>
// CHECK: %[[CST:.*]] = constant dense<[-1, 0, 1, 2147483647, -2147483648]> : tensor<5xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_i32_to_ui8
func @cast_i32_to_ui8() -> tensor<6xui8> {
%cst = constant dense<[0, -1, 256, 127, -128, -129]> : tensor<6xi32>
%0 = "tfl.cast"(%cst) : (tensor<6xi32>) -> tensor<6xui8>
return %0 : tensor<6xui8>
// CHECK: %[[CST:.*]] = constant dense<[0, 255, 0, 127, 128, 127]> : tensor<6xui8>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_ui8_to_i8
func @cast_ui8_to_i8() -> tensor<4xi8> {
%cst = constant dense<[0, 255, 127, 128]> : tensor<4xui8>
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi8>
return %0 : tensor<4xi8>
// CHECK: %[[CST:.*]] = constant dense<[0, -1, 127, -128]> : tensor<4xi8>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_i8_to_i32
func @cast_i8_to_i32() -> tensor<4xi32> {
%cst = constant dense<[0, 128, -1, -128]> : tensor<4xi8>
%0 = "tfl.cast"(%cst) : (tensor<4xi8>) -> tensor<4xi32>
return %0 : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<[0, -128, -1, -128]> : tensor<4xi32>
// CHECK: return %[[CST]]
}
// CHECK-LABEL: @cast_ui8_to_i32
func @cast_ui8_to_i32() -> tensor<4xi32> {
%cst = constant dense<[0, 128, 129, 255]> : tensor<4xui8>
%0 = "tfl.cast"(%cst) : (tensor<4xui8>) -> tensor<4xi32>
return %0 : tensor<4xi32>
// CHECK: %[[CST:.*]] = constant dense<[0, 128, 129, 255]> : tensor<4xi32>
// CHECK: return %[[CST]]
}

View File

@ -411,11 +411,11 @@ versions {
# CHECK-NEXT: constant dense<[5.000000e+00, 6.000000e+00, 7.000000e+00, 8.000000e+00]>
# CHECK: "tf.If"{{.+}}else_branch = @cond_false_10{{.+}}is_stateless = true{{.+}}then_branch = @cond_true_10
# CHECK: "tf.If"{{.+}}else_branch = @cond_false0{{.+}}is_stateless = false{{.+}}then_branch = @cond_true0
# CHECK: func @cond_false_10
# CHECK: func private @cond_false_10
# CHECK-NEXT: tfl.div
# CHECK: func @cond_true_10
# CHECK: func private @cond_true_10
# CHECK-NEXT: tfl.sub
# CHECK: func @cond_false0
# CHECK: func private @cond_false0
# CHECK-NEXT: tfl.mul
# CHECK: func @cond_true0
# CHECK: func private @cond_true0
# CHECK-NEXT: tfl.add

View File

@ -78,14 +78,14 @@ versions {
}
# CHECK: func @main(%[[VAL_0:.*]]: tensor<2x5x3xf32>, %[[VAL_1:.*]]: tensor<3x7xf32>) -> tensor<2x5x7xf32> attributes {tf.entry_function = {control_outputs = "", inputs = "Placeholder,Placeholder_1", outputs = "MatMul"}} {
# CHECK: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
# CHECK: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
# CHECK: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
# CHECK: %[[VAL_5:.*]] = constant unit
# CHECK: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
# CHECK: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
# CHECK: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
# CHECK: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK-DAG: %[[VAL_2:.*]] = constant dense<[1, 0]> : tensor<2xi32>
# CHECK-DAG: %[[VAL_3:.*]] = constant dense<[5, 3]> : tensor<2xi32>
# CHECK-DAG: %[[VAL_4:.*]] = constant dense<[3, 7]> : tensor<2xi32>
# CHECK-DAG: %[[VAL_5:.*]] = constant unit
# CHECK-DAG: %[[VAL_6:.*]] = constant dense<[1, 0, 0]> : tensor<3xi32>
# CHECK-DAG: %[[VAL_7:.*]] = constant dense<[1, 5, 3]> : tensor<3xi32>
# CHECK-DAG: %[[VAL_8:.*]] = constant dense<0> : tensor<3xi32>
# CHECK-DAG: %[[VAL_9:.*]] = constant dense<[1, 3, 7]> : tensor<3xi32>
# CHECK: %[[VAL_10:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_8]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>
# CHECK: %[[VAL_11:.*]] = "tfl.reshape"(%[[VAL_10]], %[[VAL_3]]) : (tensor<1x5x3xf32>, tensor<2xi32>) -> tensor<5x3xf32>
# CHECK: %[[VAL_12:.*]] = "tfl.slice"(%[[VAL_0]], %[[VAL_6]], %[[VAL_7]]) : (tensor<2x5x3xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<1x5x3xf32>

View File

@ -8,9 +8,11 @@ func @main(%arg0: tensor<1x4xf32>, %arg1: tensor<4x4xf32>, %arg2: tensor<4x4xf32
return %24 : tensor<1x4xf32>
// CHECK-LABEL: main
// seperate lines since there is no region for this op. third_party/tensorflow/compiler/mlir/lite/ir/tfl_ops.td: 3252
// CHECK: %[[RES0:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %arg22, %arg23, %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: %[[RES0:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
// CHECK: %[[RES1:.*]] = "tfl.pseudo_const"() {value = dense<{{.*}}> : tensor<1x4xf32>} : () -> tensor<1x4xf32>
// CHECK: %[[RES2:.*]] = "tfl.lstm"(%arg0, %arg1, %arg2, %arg3, %arg4, %arg5, %arg6, %arg7, %arg8, %arg9, %arg10, %arg11, %arg12, %arg13, %arg14, %arg15, %arg16, %arg17, %[[RES0]], %[[RES1]], %arg18, %arg19, %arg20, %arg21) ( {
// CHECK: }) {cell_clip = 0.000000e+00 : f32, fused_activation_function = "NONE", kernel_type = "FULL", proj_clip = 0.000000e+00 : f32} : (tensor<1x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4x4xf32>, tensor<4xf32>, tensor<1x4xf32>, tensor<1x4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>, tensor<4xf32>) -> tensor<1x4xf32>
// CHECK: return %[[RES0]]
// CHECK: return %[[RES2]]
}

View File

@ -5,8 +5,8 @@
// CHECK: func @main(%arg0: tensor<1xf32>) -> tensor<*xf32>
// CHECK: %0 = "tf.While"(%arg0) {body = @body, cond = @cond, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>
// CHECK: return %0 : tensor<*xf32>
// CHECK: func @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @cond(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @body(%arg0: tensor<*xf32>) -> tensor<*xf32>
func @main(%arg0: tensor<1xf32>) -> tensor<*xf32> {
%0 = "tf.While"(%arg0) {cond = @cond, body = @body, is_stateless = false} : (tensor<1xf32>) -> tensor<*xf32>

View File

@ -1,6 +1,6 @@
// RUN: tf-opt -tfl-prepare-composite-funcs-tf -tfl-fuse-tftext=true %s | FileCheck %s
func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -1026,11 +1026,11 @@ func @WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_true_23810(%arg0: t
return %1 : tensor<i1>
}
// CHECK: func @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
// CHECK: func private @whitespace_tokenizer_rank1(%arg0: tensor<1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<1>], tf.signature.is_stateful} {
// CHECK: %0:2 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>)
// CHECK: return %0#0, %0#1 : tensor<?x!tf.string>, tensor<?xi64>
func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._input_shapes = [#tf.shape<?x1>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%1 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
@ -2160,11 +2160,11 @@ func @WhitespaceTokenize_WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_As
// CHECK: func @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
// CHECK: func private @whitespace_tokenizer_rank2(%arg0: tensor<?x1x!tf.string> {tf._user_specified_name = "input"}) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<?x1>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<?x1x!tf.string>) -> (tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<?xi64>, tensor<?xi64>
func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._input_shapes = [#tf.shape<>], tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[0, 1]> : tensor<2xi64>} : () -> tensor<2xi64>
%1 = "tf.Const"() {value = dense<[]> : tensor<0xi64>} : () -> tensor<0xi64>
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
@ -3190,7 +3190,7 @@ func @WhitespaceTokenize_WhitespaceTokenize_RaggedGather_1_Assert_3_AssertGuard_
return %1 : tensor<i1>
}
// CHECK: func @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
// CHECK: func private @whitespace_tokenizer_rank0(%arg0: tensor<!tf.string> {tf._user_specified_name = "input"}) -> tensor<?x!tf.string> attributes {tf._implements = #tf.func<@"tftext:WhitespaceTokenizer", {}>, tf._input_shapes = [#tf.shape<>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0) {custom_code = "tftext:WhitespaceTokenizer", custom_option = opaque<"tfl", "0x"> : tensor<0xi8>} : (tensor<!tf.string>) -> tensor<?x!tf.string>
// CHECK: return %0 : tensor<?x!tf.string>
@ -3213,7 +3213,7 @@ func @ngrams(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "input"}) ->
// CHECK: return %0 : tensor<?x!tf.string>
// CHECK: }
func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>
%1 = "tf.Const"() {value = dense<-1> : tensor<i64>} : () -> tensor<i64>
%2 = "tf.Const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
@ -3330,12 +3330,12 @@ func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name
%71 = "tf.Identity"(%70) {device = ""} : (tensor<3xi64>) -> tensor<3xi64>
return %68, %71, %64 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_27770(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_27780(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3345,12 +3345,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_28130(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_28140(%arg0: tensor<i1>, %arg1: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3359,12 +3359,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_RowPartitionFromRowSplits_as
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28500(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28510(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits/strided_slice_1:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3374,12 +3374,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_assert_equal_1_Assert_Assert
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_true_28900(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_equal_1_Assert_AssertGuard_false_28910(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:zero"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3389,12 +3389,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_true_29260(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_assert_non_negative_assert_less_equal_Assert_AssertGuard_false_29270(%arg0: tensor<i1>, %arg1: tensor<2xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<2>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to from_row_splits do not form a valid RaggedTensor:monotonic"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x >= 0 did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/RowPartitionFromRowSplits/sub:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3403,12 +3403,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_RowPartitionFromRowSplits_
%4 = "tf.Identity"(%3) {device = ""} : (tensor<i1>) -> tensor<i1>
return %4 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_true_29650(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
func private @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_AssertGuard_false_29660(%arg0: tensor<i1>, %arg1: tensor<i64>, %arg2: tensor<i64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<>, #tf.shape<>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Arguments to _from_row_partition do not form a valid RaggedTensor"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (RaggedFromNestedRowSplits/RaggedFromRowSplits_1/strided_slice:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3418,12 +3418,12 @@ func @RaggedFromNestedRowSplits_RaggedFromRowSplits_1_assert_equal_1_Assert_Asse
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_true_30330(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>]} {
%0 = "tf.Identity"(%arg0) {device = ""} : (tensor<i1>) -> tensor<i1>
%1 = "tf.Identity"(%0) {device = ""} : (tensor<i1>) -> tensor<i1>
return %1 : tensor<i1>
}
func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {sym_visibility = "private", tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
func private @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_30340(%arg0: tensor<i1>, %arg1: tensor<?xi64>, %arg2: tensor<?xi64>) -> tensor<i1> attributes {tf._input_shapes = [#tf.shape<>, #tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<"Inputs must have identical ragged splits"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%1 = "tf.Const"() {value = dense<"Condition x == y did not hold element-wise:"> : tensor<!tf.string>} : () -> tensor<!tf.string>
%2 = "tf.Const"() {value = dense<"x (NGrams/SlidingWindow/RaggedGetItem/RaggedRange:0) = "> : tensor<!tf.string>} : () -> tensor<!tf.string>
@ -3433,12 +3433,12 @@ func @NGrams_SlidingWindow_RaggedConcat_assert_equal_2_Assert_AssertGuard_false_
%5 = "tf.Identity"(%4) {device = ""} : (tensor<i1>) -> tensor<i1>
return %5 : tensor<i1>
}
// CHECK: func @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: func private @ngrams_ragged_rank_2(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<3xi64> {tf._user_specified_name = "args_0"}, %arg2: tensor<?xi64> {tf._user_specified_name = "args_1"}) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) attributes {tf._implements = #tf.func<@"tftext:Ngrams", {axis = -1 : i64, reduction_type = "STRING_JOIN", string_separator = "", width = 2 : i64}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<3>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0:3 = "tfl.custom"(%arg0, %arg1, %arg2) {custom_code = "tftext:Ngrams", custom_option = opaque<"tfl", "0x776964746800737472696E675F736570617261746F720000006178697300726564756374696F6E5F74797065000B535452494E475F4A4F494E0004221E373E040104FF152C0204141404082401"> : tensor<77xi8>} : (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>) -> (tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>)
// CHECK: return %0#0, %0#1, %0#2 : tensor<?x!tf.string>, tensor<3xi64>, tensor<?xi64>
func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
%0 = "tf.Const"() {value = dense<[[1902835825], [-1475704015], [473120514], [1254202069], [1558833093], [1756181982], [1906603252], [-1034142694], [542842690], [535515822]]> : tensor<10x1xi64>} : () -> tensor<10x1xi64>
%1 = "tf.StringToHashBucketFast"(%arg0) {device = "", num_buckets = 2147483647 : i64} : (tensor<?x!tf.string>) -> tensor<?xi64>
%2 = "tf.Sgnn"(%1, %0) {device = ""} : (tensor<?xi64>, tensor<10x1xi64>) -> tensor<10x?xf64>
@ -3448,6 +3448,6 @@ func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "va
}
// CHECK: func @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {sym_visibility = "private", tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: func private @sgnn_projection(%arg0: tensor<?x!tf.string> {tf._user_specified_name = "values"}, %arg1: tensor<?xi64> {tf._user_specified_name = "row_splits"}) -> tensor<?x10xf64> attributes {tf._implements = #tf.func<@"tftext:custom:SgnnProjection", {buckets = 2147483647 : i64, hash_seed = [1902835825, -1475704015, 473120514, 1254202069, 1558833093, 1756181982, 1906603252, -1034142694, 542842690, 535515822]}>, tf._input_shapes = [#tf.shape<?>, #tf.shape<?>], tf.signature.is_stateful} {
// CHECK: %0 = "tfl.custom"(%arg0, %arg1) {custom_code = "tftext:custom:SgnnProjection", custom_option = opaque<"tfl", "0x686173685F736565640000000A00000071F86A71318B0AA8023F331CD59AC14AC5E7E95CDE35AD68F474A4711A3C5CC2421F5B20AE52EB1F6275636B6574730002094200030000000100000002000000FFFFFF7F44000000062E0A2601"> : tensor<93xi8>} : (tensor<?x!tf.string>, tensor<?xi64>) -> tensor<?x10xf64>
// CHECK: return %0 : tensor<?x10xf64>

View File

@ -0,0 +1,40 @@
// RUN: tf-opt -split-input-file -tfl-insert-call-once-op %s | FileCheck %s
// Tests that new call_once op is added when there is a session initializer.
module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() {initializers = [@init_all_tables]} : () -> ()
func @init_all_tables()
attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
%cst = constant dense<[1, 2, 3, 4]> : tensor<4xi64>
%cst_0 = constant dense<["a", "b", "c", "d"]> : tensor<4x!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
"tf.LookupTableImportV2"(%0, %cst, %cst_0) {device = ""} : (tensor<!tf.resource>, tensor<4xi64>, tensor<4x!tf.string>) -> ()
return
// CHECK-LABEL: @init_all_tables
}
func @serving_default(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<*x!tf.string> {tf_saved_model.index_path = ["r"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "hash_table_Lookup/LookupTableFindV2:0"}, tf_saved_model.exported_names = ["serving_default"]} {
%cst = constant dense<"f"> : tensor<!tf.string>
%0 = "tf.HashTableV2"() {container = "", device = "", key_dtype = i64, shared_name = "hash_table_dba2ccaa-f1b1-46d6-b276-98008f69da71", use_node_name_sharing = false, value_dtype = !tf.string} : () -> tensor<!tf.resource>
%1 = "tf.LookupTableFindV2"(%0, %arg0, %cst) {device = ""} : (tensor<!tf.resource>, tensor<i64>, tensor<!tf.string>) -> tensor<*x!tf.string>
return %1 : tensor<*x!tf.string>
// CHECK-LABEL: @serving_default
// CHECK: "tfl.call_once"() {session_init_function = "init_all_tables"} : () -> ()
}
}
// -----
// Tests that no call_once op is added.
module attributes {tf_saved_model.semantics} {
func @no_call_once(%arg0: tensor<i64> {tf_saved_model.index_path = ["x"]}) -> (tensor<i64> {tf_saved_model.index_path = ["r"]})
attributes {tf.entry_function = {control_outputs = "", inputs = "input:0", outputs = "output:0"}, tf_saved_model.exported_names = ["serving_default"]} {
return %arg0 : tensor<i64>
// CHECK-LABEL: no_call_once
// CHECK-NOT: "tfl.call_once"
}
}

View File

@ -1122,6 +1122,13 @@ func @strided_slice_with_constant_attributes(%arg0: tensor<10x10x10xf32>, %arg1:
// CHECK-NEXT: "tfl.strided_slice"(%arg0, [[BEGIN]], [[END]], [[STRIDES]]) {begin_mask = 6 : i32, ellipsis_mask = 0 : i32, end_mask = 6 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 1 : i32} : (tensor<10x10x10xf32>, tensor<3xi32>, tensor<3xi32>, tensor<3xi32>) -> tensor<10x10xf32>
}
func @strided_slice_with_string(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
%0 = "tf.StridedSlice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i64, ellipsis_mask = 0 : i64, end_mask = 0 : i64, new_axis_mask = 0 : i64, shrink_axis_mask = 0 : i64} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
return %0 : tensor<1x2x2x5x!tf.string>
// CHECK-LABEL: strided_slice_with_string
// CHECK: "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
}
func @slice1Tensor(%arg0: tensor<2x3x5xf32>, %arg1: tensor<3xi32>, %arg2: tensor<3xi32>) -> tensor<?x3x5xf32> {
%0 = "tf.Slice"(%arg0, %arg1, %arg2) : (tensor<2x3x5xf32>, tensor<3xi32>, tensor<3xi32>) -> tensor<?x3x5xf32>
return %0 : tensor<?x3x5xf32>

View File

@ -1458,6 +1458,12 @@ func @testStridedSliceTFType(%arg0: tensor<12x2x2x5xui8>, %arg1: tensor<1xi32>,
return %0 : tensor<1x2x2x5x!tf.quint8>
}
// CHECK-LABEL: testStridedSliceWithString
func @testStridedSliceWithString(%arg0: tensor<12x2x2x5x!tf.string>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5x!tf.string> {
%0 = "tfl.strided_slice"(%arg0, %arg1, %arg2, %arg3) {begin_mask = 0 : i32, ellipsis_mask = 0 : i32, end_mask = 0 : i32, new_axis_mask = 0 : i32, shrink_axis_mask = 0 : i32} : (tensor<12x2x2x5x!tf.string>, tensor<1xi32>, tensor<1xi32>, tensor<1xi32>) -> tensor<1x2x2x5x!tf.string>
return %0 : tensor<1x2x2x5x!tf.string>
}
// -----
func @testStridedSliceWithInvalidOutputType(%arg0: tensor<12x2x2x5xf32>, %arg1: tensor<1xi32>, %arg2: tensor<1xi32>, %arg3: tensor<1xi32>) -> tensor<1x2x2x5xi32> {

View File

@ -407,16 +407,16 @@ func @fuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112
}
// CHECK-LABEL: @notFuseMulIntoDepthwiseConv2d
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
func @notFuseMulIntoDepthwiseConv2d(%arg0: tensor<1x4x4x2xf32>) -> tensor<1x4x4x2xf32> {
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]], [[7.0, 8.0], [9.0, 10.0], [11.0, 12.0]], [[13.0, 14.0], [15.0, 16.0], [17.0, 18.0]]]]> : tensor<1x3x3x2xf32>
%cst1 = constant dense<2.0> : tensor<2xf32>
%cst2 = constant dense<3.0> : tensor<112x2xf32>
%cst2 = constant dense<[[3.1, 3.2], [3.1, 3.2], [3.1, 3.2], [3.1, 3.2]]> : tensor<4x2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x112x112x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
%0 = "tfl.depthwise_conv_2d"(%arg0, %cst0, %cst1) {depth_multiplier = 1 : i32, dilation_h_factor = 1 : i32, dilation_w_factor = 1 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 1 : i32, stride_w = 1 : i32} : (tensor<1x4x4x2xf32>, tensor<1x3x3x2xf32>, tensor<2xf32>) -> tensor<1x4x4x2xf32>
// We cannot fuse this tfl.mul into the preceding conv op because %cst2 is not broadcast-compatible to %cst0.
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x112x112x2xf32>, tensor<112x2xf32>) -> tensor<1x112x112x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "RELU6"} : (tensor<1x4x4x2xf32>, tensor<4x2xf32>) -> tensor<1x4x4x2xf32>
return %1 : tensor<1x112x112x2xf32>
return %1 : tensor<1x4x4x2xf32>
// CHECK: %0 = "tfl.depthwise_conv_2d"(%arg0, %cst, %cst_0)
// CHECK: %1 = "tfl.mul"(%0, %cst_1)
@ -484,17 +484,17 @@ func @FuseFullyConnectedAddWithScalarRhs(%arg0: tensor<40x37xf32>, %arg1: tensor
}
// CHECK-LABEL: @FuseFullyConnectedAddWithUnfusableRhs
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
func @FuseFullyConnectedAddWithUnfusableRhs(%arg0: tensor<4x37xf32>, %arg1: tensor<4x37xf32>) -> tensor<4x4xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<40x40xf32>
%cst2 = constant dense<[[2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3], [2.0, 2.1, 2.2, 2.3]]> : tensor<4x4xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<4x37xf32>, tensor<4x37xf32>, none) -> (tensor<4x4xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<4x4xf32>, tensor<4x4xf32>) -> tensor<4x4xf32>
return %1 : tensor<40x40xf32>
return %1 : tensor<4x4xf32>
// CHECK: %[[unit:.*]] = constant unit
// CHECK: %[[filter:.*]] = constant dense<2.000000e+00> : tensor<40x40xf32>
// CHECK: %[[filter:.*]] = constant dense<{{.*}}> : tensor<4x4xf32>
// CHECK: %[[fc_result:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[unit]])
// CHECK: %[[add_result:.*]] = tfl.add %[[fc_result]], %[[filter]]
// CHECK: return %[[add_result]]
@ -851,17 +851,17 @@ func @fuseDivIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x1
}
// CHECK-LABEL: @fuseMulIntoConv2d_Scalar
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
func @fuseMulIntoConv2d_Scalar(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x1xf32> {
%cst0 = constant dense<[[[[1.0, 2.0], [3.0, 4.0]], [[5.0, 6.0], [7.0, 8.0]]]]> : tensor<1x2x2x2xf32>
%cst1 = constant dense<1.0> : tensor<2xf32>
%cst1 = constant dense<1.0> : tensor<1xf32>
%cst2 = constant dense<2.0> : tensor<f32>
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<f32>) -> tensor<1x112x112x2xf32>
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x1xf32>, tensor<f32>) -> tensor<1x112x112x1xf32>
return %1 : tensor<1x112x112x2xf32>
return %1 : tensor<1x112x112x1xf32>
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00], [6.000000e+00, 8.000000e+00]], {{\[\[}}1.000000e+01, 1.200000e+01], [1.400000e+01, 1.600000e+01]]]]> : tensor<1x2x2x2xf32>
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<1xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<1x2x2x2xf32>, tensor<1xf32>) -> tensor<1x112x112x1xf32>
// CHECK: return %[[RES]]
}
@ -1397,3 +1397,33 @@ func @fuseExpanded1DMulIntoConv2d(%arg0: tensor<1x8x8x207xf32>) -> tensor<1x8x8x
// CHECK: "tfl.conv_2d"(%arg0, %[[CST_0]], %[[CST_1]])
}
// CHECK-LABEL: @FuseFullyConnectedAddWithSplat2D
func @FuseFullyConnectedAddWithSplat2D(%arg0: tensor<40x37xf32>, %arg1: tensor<40x37xf32>) -> tensor<40x40xf32> {
%cst = constant unit
%cst2 = constant dense<2.0> : tensor<40x40xf32>
%0 = "tfl.fully_connected" (%arg0, %arg1, %cst) {fused_activation_function = "NONE", keep_num_dims = false, weights_format = "DEFAULT"} : (tensor<40x37xf32>, tensor<40x37xf32>, none) -> (tensor<40x40xf32>)
%1 = "tfl.add"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<40x40xf32>, tensor<40x40xf32>) -> tensor<40x40xf32>
return %1 : tensor<40x40xf32>
// CHECK: %[[BIAS:.*]] = constant dense<2.000000e+00> : tensor<40xf32>
// CHECK: %[[FC_RESULT:.*]] = "tfl.fully_connected"(%arg0, %arg1, %[[BIAS]])
// CHECK: return %[[FC_RESULT]]
}
// CHECK-LABEL: @fuseMulIntoConv2d_Splat2D
func @fuseMulIntoConv2d_Splat2D(%arg0: tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32> {
%cst0 = constant dense<[[[[1.0, 2.0]]], [[[3.0, 4.0]]]]> : tensor<2x1x1x2xf32>
%cst1 = constant dense<1.0> : tensor<2xf32>
%cst2 = constant dense<2.0> : tensor<1x112x112x2xf32>
%0 = "tfl.conv_2d"(%arg0, %cst0, %cst1) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
%1 = "tfl.mul"(%0, %cst2) {fused_activation_function = "NONE"} : (tensor<1x112x112x2xf32>, tensor<1x112x112x2xf32>) -> tensor<1x112x112x2xf32>
return %1 : tensor<1x112x112x2xf32>
// CHECK: %[[CST1:.*]] = constant dense<{{\[\[\[\[}}2.000000e+00, 4.000000e+00]]], {{\[\[\[}}6.000000e+00, 8.000000e+00]]]]> : tensor<2x1x1x2xf32>
// CHECK: %[[CST2:.*]] = constant dense<2.000000e+00> : tensor<2xf32>
// CHECK: %[[RES:[0-9].*]] = "tfl.conv_2d"(%arg0, %[[CST1]], %[[CST2]]) {dilation_h_factor = 2 : i32, dilation_w_factor = 3 : i32, fused_activation_function = "NONE", padding = "SAME", stride_h = 4 : i32, stride_w = 5 : i32} : (tensor<1x112x112x2xf32>, tensor<2x1x1x2xf32>, tensor<2xf32>) -> tensor<1x112x112x2xf32>
// CHECK: return %[[RES]]
}

View File

@ -77,3 +77,32 @@ func @HandleReturnedDequantizeWithAnotherUse(%arg0: tensor<128x16xf32>) -> (tens
// CHECK-NEXT: return %[[softmax]], %[[argmax]] : tensor<128x16xf32>, tensor<128xi32>
return %2, %3 : tensor<128x16xf32>, tensor<128xi32>
}
// CHECK-LABEL: PruneUnusedLstm
func @PruneUnusedLstm(%arg0: tensor<1x28x28xf32>) -> (tensor<1x28x28xf32>) {
%input = "tfl.quantize"(%arg0) {qtype = tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>} : (tensor<1x28x28xf32>) -> tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>
%cst_1 = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi8>} : () -> tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>
%cst_2 = constant unit
%cst_3 = "tfl.pseudo_qconst"() {qtype = tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x20xi8>} : () -> tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>
%cst_7 = "tfl.pseudo_qconst"() {qtype = tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20xi8>} : () -> tensor<20x!quant.uniform<i8:f32, 0.006:-34>>
%cst_11 = "tfl.pseudo_qconst"() {qtype = tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, value = dense<1> : tensor<20x28xi8>} : () -> tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>
%cell_input = "tfl.pseudo_qconst"() {qtype = tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>, value = dense<1> : tensor<1x20xi6>} : () -> tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>
%0 = "tfl.unidirectional_sequence_lstm"(%input,
%cst_11, %cst_11, %cst_11, %cst_11,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_2, %cst_2, %cst_2,
%cst_7, %cst_7, %cst_7, %cst_7,
%cst_2, %cst_2,
%cst_1, %cell_input,
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
: ( tensor<1x28x28x!quant.uniform<i8:f32, 0.003:-128>>,
tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x28x!quant.uniform<i8:f32, 0.006:-34>>,
tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x20x!quant.uniform<i8:f32, 0.006:-34>>,
none, none, none,
tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<20x!quant.uniform<i8:f32, 0.006:-34>>,
none, none,
tensor<1x20x!quant.uniform<i8:f32, 0.006:-34>>, tensor<1x20x!quant.uniform<i16:f32, 0.006:-34>>,
none, none, none, none) -> tensor<1x28x20x!quant.uniform<i8:f32, 0.006:-34>>
return %arg0 : tensor<1x28x28xf32>
// CHECK-NEXT: return %arg0
}

View File

@ -166,3 +166,37 @@ func @QuantizeTransposeConv(%arg0: tensor<32x4x4x128xf32>, %arg1: tensor<4xi32>)
// PerTensor: %[[DEQUANTIZE:.*]] = "tfl.dequantize"(%[[QUANTIZE]]) : (tensor<1x32x42x128x!quant.uniform<i8<-127:127>:f32, 1.000000e+00>>) -> tensor<1x32x42x128xf32>
// PerTensor: "tfl.transpose_conv"(%arg1, %arg0, %[[DEQUANTIZE]]
}
// CHECK-LABEL: QuantizeLstmCellInput
func @QuantizeLstmCellInput(%arg0: tensor<1x28x28xf32>) -> tensor<1x28x20xf32> {
%cst_1 = constant dense<1.0> : tensor<1x20xf32>
%cst_2 = constant unit
%cst_3 = constant dense<1.0> : tensor<20x20xf32>
%cst_7 = constant dense<1.0> : tensor<20xf32>
%cst_11 = constant dense<1.0> : tensor<20x28xf32>
%cell_input = constant dense<0.0> : tensor<1x20xf32>
%cell_stats = "quant.stats"(%cell_input) {layerStats = dense<[-2.73090601, 7.94872093]> : tensor<2xf32>} : (tensor<1x20xf32>) -> tensor<1x20xf32>
%0 = "tfl.unidirectional_sequence_lstm"(%arg0,
%cst_11, %cst_11, %cst_11, %cst_11,
%cst_3, %cst_3, %cst_3, %cst_3,
%cst_2, %cst_2, %cst_2,
%cst_7, %cst_7, %cst_7, %cst_7,
%cst_2, %cst_2,
%cst_1, %cell_stats,
%cst_2, %cst_2, %cst_2, %cst_2) {cell_clip = 1.000000e+01 : f32, fused_activation_function = "TANH", proj_clip = 0.000000e+00 : f32, time_major = false}
: ( tensor<1x28x28xf32>,
tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>, tensor<20x28xf32>,
tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>, tensor<20x20xf32>,
none, none, none,
tensor<20xf32>, tensor<20xf32>, tensor<20xf32>, tensor<20xf32>,
none, none,
tensor<1x20xf32>, tensor<1x20xf32>,
none, none, none, none) -> tensor<1x28x20xf32>
return %0 : tensor<1x28x20xf32>
// CHECK: %[[none:.*]] = constant unit
// CHECK: %[[cell_input:.*]] = constant dense<0.000000e+00> : tensor<1x20xf32>
// CHECK: %[[q:.*]] = "tfl.quantize"(%[[cell_input]]) {qtype = tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>} : (tensor<1x20xf32>) -> tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>
// CHECK: %[[dq:.*]] = "tfl.dequantize"(%[[q]]) : (tensor<1x20x!quant.uniform<i16:f32, 2.44140625E-4>>) -> tensor<1x20xf32>
// Checks if input 19 is correctly passed from a dequantize op.
// CHECK: %[[lstm:.*]] = "tfl.unidirectional_sequence_lstm"(%arg0, {{(%[^%,]+, )+}}%[[dq]], %[[none]], %[[none]], %[[none]], %[[none]])
}

View File

@ -30,9 +30,9 @@ func @while() -> tensor<1xf32>
}) : (tensor<i32>, tensor<1xf32>) -> (tensor<i32>, tensor<1xf32>) loc("WhileOp")
return %0#1 : tensor<1xf32>
}
// CHECK-LABEL: func @WhileOp_cond(
// CHECK-LABEL: func private @WhileOp_cond(
// CHECK: tfl.greater
// CHECK-LABEL: func @WhileOp_body(
// CHECK-LABEL: func private @WhileOp_body(
// CHECK: tfl.sub
// CHECK: tfl.add
@ -63,21 +63,21 @@ func @while2(%cst : tensor<i32>) -> tensor<1xf32> attributes {tf.entry_function
return %0#1 : tensor<1xf32>
}
func @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> attributes {sym_visibility = "private"} {
func private @WhileOp_cond(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> tensor<i1> {
%cst = constant dense<0> : tensor<i32>
%0 = "tfl.greater"(%arg0, %cst) : (tensor<*xi32>, tensor<i32>) -> tensor<i1>
return %0 : tensor<i1>
}
func @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) attributes {sym_visibility = "private"} {
func private @WhileOp_body(%arg0: tensor<*xi32>, %arg1: tensor<*xf32>, %arg2: tensor<i32>) -> (tensor<*xi32>, tensor<*xf32>, tensor<i32>) {
%0 = "tfl.sub"(%arg0, %arg2) {fused_activation_function = "NONE"} : (tensor<*xi32>, tensor<i32>) -> tensor<*xi32>
%1 = tfl.add %arg1, %arg1 {fused_activation_function = "NONE"} : tensor<*xf32>
return %0, %1, %arg2 : tensor<*xi32>, tensor<*xf32>, tensor<i32>
}
// CHECK-LABEL: func @WhileOp_cond(
// CHECK-LABEL: func private @WhileOp_cond(
// CHECK: tfl.greater
// CHECK-LABEL: func @WhileOp_body(
// CHECK-LABEL: func private @WhileOp_body(
// CHECK: tfl.sub
// CHECK: tfl.add
@ -152,14 +152,14 @@ func @rnn(%arg0: tensor<4x4x3xf32> {tf.device = "/device:CPU:0"}) -> tensor<4x?x
// CHECK: tfl.yield
// CHECK-SAME: (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) -> ()
// CHECK-LABEL: func @tfl.while_cond(
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> attributes {sym_visibility = "private"} {
// CHECK-LABEL: func private @tfl.while_cond(
// CHECK-SAME: [[VAL_35:%.*]]: tensor<i32>, [[VAL_36:%.*]]: tensor<i32>, [[VAL_37:%.*]]: tensor<*xf32>, [[VAL_38:%.*]]: tensor<4x2xf32>, [[VAL_39:%.*]]: tensor<4x2xf32>, [[VAL_40:%.*]]: tensor<*xf32>, [[VAL_41:%.*]]: tensor<4x4x3xf32>) -> tensor<i1> {
// CHECK: return
// CHECK-SAME: tensor<i1>
// CHECK: }
// CHECK-LABEL: func @tfl.while_body(
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) attributes {sym_visibility = "private"} {
// CHECK-LABEL: func private @tfl.while_body(
// CHECK-SAME: [[VAL_46:%.*]]: tensor<i32>, [[VAL_47:%.*]]: tensor<i32>, [[VAL_48:%.*]]: tensor<*xf32>, [[VAL_49:%.*]]: tensor<4x2xf32>, [[VAL_50:%.*]]: tensor<4x2xf32>, [[VAL_51:%.*]]: tensor<*xf32>, [[VAL_52:%.*]]: tensor<4x4x3xf32>) -> (tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>) {
// CHECK: [[VAL_91:%.*]] = "tfl.cast"
// CHECK: return
// CHECK-SAME: [[VAL_91]], [[VAL_52]] : tensor<i32>, tensor<i32>, tensor<*xf32>, tensor<4x2xf32>, tensor<4x2xf32>, tensor<*xf32>, tensor<4x4x3xf32>

View File

@ -234,6 +234,11 @@ void AddTFToTFLConversionPasses(const mlir::TFL::PassConfig& pass_config,
// tf.variable to model this.
pass_manager->addNestedPass<mlir::FuncOp>(
mlir::TFL::CreateSplitMergedOperandsPass());
// Add CallOnceOp when there is a session initializer function in tf saved
// model dialect.
pass_manager->addPass(
mlir::TFL::CreateInsertCallOnceOpFromSessionInitializerPass());
}
}

View File

@ -0,0 +1,78 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "mlir/IR/OperationSupport.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "tensorflow/compiler/mlir/lite/ir/tfl_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_saved_model.h"
namespace mlir {
namespace TFL {
namespace {
// This pass inserts a TFL::CallOnce op when tf_saved_model's session
// initializer is given.
class InsertCallOnceOpFromSessionInitializerPass
: public mlir::PassWrapper<InsertCallOnceOpFromSessionInitializerPass,
OperationPass<ModuleOp>> {
private:
void runOnOperation() override;
};
void InsertCallOnceOpFromSessionInitializerPass::runOnOperation() {
ModuleOp module = getOperation();
tf_saved_model::SessionInitializerOp session_init_op =
tf_saved_model::GetSessionInitializerOp(module);
if (!session_init_op) return;
SymbolTable symbol_table(module);
for (auto sym_ref : session_init_op.initializers()) {
FuncOp init_func_op = symbol_table.lookup<mlir::FuncOp>(
sym_ref.cast<FlatSymbolRefAttr>().getValue());
if (!init_func_op) {
module.emitError("no session initializer function found");
return signalPassFailure();
}
for (auto func : module.getOps<FuncOp>()) {
auto dict_attr =
func.getAttrOfType<mlir::DictionaryAttr>("tf.entry_function");
if (!dict_attr) continue;
OpBuilder builder(func.getContext());
builder.setInsertionPointToStart(&func.getBlocks().front());
builder.create<TFL::CallOnceOp>(func.getLoc(), init_func_op.getName());
}
}
}
} // namespace
// Inserts a TFL::CallOnce op when tf_saved_model's session initializer is
// given.
std::unique_ptr<OperationPass<ModuleOp>>
CreateInsertCallOnceOpFromSessionInitializerPass() {
return std::make_unique<InsertCallOnceOpFromSessionInitializerPass>();
}
static PassRegistration<InsertCallOnceOpFromSessionInitializerPass> pass(
"tfl-insert-call-once-op",
"Insert CallOnce op when tf_saved_model's session initializer is given");
} // namespace TFL
} // namespace mlir

View File

@ -27,11 +27,14 @@ limitations under the License.
#include "llvm/ADT/APFloat.h"
#include "llvm/ADT/APInt.h"
#include "llvm/ADT/ArrayRef.h"
#include "llvm/ADT/None.h"
#include "llvm/ADT/Optional.h"
#include "llvm/ADT/SmallSet.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/ADT/StringSwitch.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/raw_ostream.h"
#include "mlir/Dialect/StandardOps/IR/Ops.h" // from @llvm-project
#include "mlir/IR/Attributes.h" // from @llvm-project
#include "mlir/IR/Matchers.h" // from @llvm-project
@ -729,6 +732,143 @@ struct FuseBinaryOpToFollowingAffineOp : public OpRewritePattern<AffineOpType> {
}
};
// If the operand to a broadcastable op is a splat constant, try to replace it
// with a 0-d constant, e.g. before this optimization,
// %cst = constant dense<1.0> : tensor<16x16x4xf32>
// %0 = "tfl.conv_2d"...
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<16x16x4xf32>)
// After this optimization:
// %cst = constant dense<1.0> : tensor<f32>
// %0 = "tfl.conv_2d"...
// %1 = "tfl.add"(%0, %cst) : (tensor<16x16x4xf32>, tensor<f32>)
// This pattern can enable more fusing opportunities when the binary op is
// following conv ops.
template <typename BinaryOpType>
struct ScalarizeSplatConstantForBroadcastableOps
: public OpRewritePattern<BinaryOpType> {
using OpRewritePattern<BinaryOpType>::OpRewritePattern;
LogicalResult matchAndRewrite(BinaryOpType binary_op,
PatternRewriter &rewriter) const override {
DenseElementsAttr splat_elements_attr;
if (!IsScalarizableSplatConstant(binary_op.rhs(), &splat_elements_attr)) {
return failure();
}
constexpr int kSplatOperandIndex = 1;
auto result_type =
binary_op.getResult().getType().template cast<ShapedType>();
mlir::Value non_splat_operand =
binary_op.getOperand(1 - kSplatOperandIndex);
auto non_splat_operand_type =
non_splat_operand.getType().cast<ShapedType>();
// If the other operand's shape does not equal to the result shape, then we
// cannot scalarize the splat constant because the result shape relies on
// the splat constant op's shape for broadcasting.
if (!non_splat_operand_type.hasStaticShape() ||
non_splat_operand_type.getShape() != result_type.getShape()) {
return failure();
}
// If non-splat operand is not fusable affine ops, then no need to apply
// this transformation.
if (!CanFuseAffineOp(non_splat_operand.getDefiningOp(), binary_op)) {
return failure();
}
// Creates a new scalar constant op using the splat value.
mlir::Value splat_operand = binary_op.getOperand(kSplatOperandIndex);
auto scalar_elements_attr = DenseElementsAttr::get(
RankedTensorType::get({},
splat_elements_attr.getType().getElementType()),
splat_elements_attr.getSplatValue());
auto scalar_constant_op = rewriter.create<ConstantOp>(
splat_operand.getLoc(), scalar_elements_attr.getType(),
scalar_elements_attr);
binary_op.setOperand(kSplatOperandIndex, scalar_constant_op);
return success();
}
private:
// Returns true if this value is a splat constant op which can be scalarized.
// Also returns the elements attr if this value is indeed a splat constant.
bool IsScalarizableSplatConstant(mlir::Value value,
DenseElementsAttr *elements_attr) const {
if (!matchPattern(value, m_Constant(elements_attr))) {
return false;
}
auto element_type = value.getType().cast<ShapedType>().getElementType();
// Ignore per-axis quantized constants because after converting to scalar,
// we will lose per-axis qantization parameter.
if (element_type.isa<quant::UniformQuantizedPerAxisType>()) {
return false;
}
if (IsScalar(value)) {
return false;
}
return elements_attr->isSplat();
}
// If this type is a scalar shaped type.
bool IsScalar(mlir::Value value) const {
auto type = value.getType().dyn_cast<ShapedType>();
if (!type) {
return false;
}
if (!type.hasStaticShape()) {
return false;
}
return type.getNumElements() == 1;
}
// Returns true if we can fuse an affine op with consuming binary op.
bool CanFuseAffineOp(Operation *affine_op, Operation *binary_op) const {
if (!isa_and_nonnull<TFL::Conv2DOp, TFL::DepthwiseConv2DOp,
TFL::FullyConnectedOp>(affine_op)) {
return false;
}
DenseElementsAttr value;
// Check that bias are constants if not none.
Value bias = affine_op->getOperand(2);
if (!bias.getType().isa<NoneType>() &&
!matchPattern(bias, m_Constant(&value))) {
return false;
}
// If the binary op is mul/div, also check that filter is constant.
if (isa<TFL::MulOp, TFL::DivOp>(binary_op) &&
!matchPattern(affine_op->getOperand(1), m_Constant(&value))) {
return false;
}
// We can only fuse F32/BF16.
auto is_fusable_type = [](Type t) {
Type element_type = t;
if (auto shaped_type = t.dyn_cast<ShapedType>()) {
element_type = shaped_type.getElementType();
}
return element_type.isBF16() || element_type.isF32();
};
for (Type t : binary_op->getOperandTypes()) {
if (!is_fusable_type(t)) {
return false;
}
}
return true;
}
};
using ScalarizeSplatConstantForSub =
ScalarizeSplatConstantForBroadcastableOps<TFL::SubOp>;
using ScalarizeSplatConstantForAdd =
ScalarizeSplatConstantForBroadcastableOps<TFL::AddOp>;
using ScalarizeSplatConstantForMul =
ScalarizeSplatConstantForBroadcastableOps<TFL::MulOp>;
using ScalarizeSplatConstantForDiv =
ScalarizeSplatConstantForBroadcastableOps<TFL::DivOp>;
struct ConvertTrivialTransposeOpToReshapeOp
: public OpRewritePattern<TFL::TransposeOp> {
using OpRewritePattern<TFL::TransposeOp>::OpRewritePattern;
@ -818,6 +958,8 @@ void Optimize::runOnFunction() {
OwningRewritePatternList phase_2_patterns;
TFL::populateWithGenerated(ctx, phase_2_patterns);
phase_2_patterns.insert<
ScalarizeSplatConstantForAdd, ScalarizeSplatConstantForSub,
ScalarizeSplatConstantForMul, ScalarizeSplatConstantForDiv,
FuseFullyConnectedAndAdd, FuseFullyConnectedAndReluX<TFL::ReluOp, kRelu>,
FuseFullyConnectedAndReluX<TFL::Relu6Op, kRelu6>,
FuseFullyConnectedAndReluX<TFL::Relu1Op, kRelu1>,

View File

@ -94,6 +94,10 @@ std::unique_ptr<OperationPass<FuncOp>> CreateRuntimeVerifyPass();
// Creates raise custom ops pass, which legalize custom ops to TFL::CustomOp
std::unique_ptr<OperationPass<FuncOp>> CreateRaiseCustomOpsPass();
// Inserts an TFL::CallOnce op when the tf_saved_model's session initialzer is
// given.
std::unique_ptr<OperationPass<ModuleOp>>
CreateInsertCallOnceOpFromSessionInitializerPass();
} // namespace TFL
} // namespace mlir

View File

@ -139,6 +139,30 @@ struct RemoveVolatileOps : public OpRewritePattern<DequantizeOp> {
}
};
// Removes LSTMs that have dangling output.
// LSTMs are not removed automatically becuase they are stateful ops.
template <typename LstmOpTy>
struct PruneUnusedLstm : public OpRewritePattern<LstmOpTy> {
public:
explicit PruneUnusedLstm(MLIRContext* context)
: OpRewritePattern<LstmOpTy>(context) {}
LogicalResult matchAndRewrite(LstmOpTy lstm_op,
PatternRewriter& rewriter) const override {
Operation* op = lstm_op.getOperation();
if (op->isKnownTerminator()) {
return failure();
}
for (auto result : op->getOpResults()) {
if (!result.use_empty()) {
return failure();
}
}
rewriter.eraseOp(op);
return success();
}
};
#include "tensorflow/compiler/mlir/lite/transforms/generated_post_quantize.inc"
void PostQuantizePass::runOnFunction() {
@ -147,6 +171,7 @@ void PostQuantizePass::runOnFunction() {
auto* ctx = func.getContext();
TFL::populateWithGenerated(ctx, patterns);
patterns.insert<quant::FoldTrivalRequantizeOp<QuantizeOp>>(ctx);
patterns.insert<PruneUnusedLstm<TFL::UnidirectionalSequenceLSTMOp>>(ctx);
applyPatternsAndFoldGreedily(func, std::move(patterns));
if (!emit_quant_adaptor_ops_) {

View File

@ -14,6 +14,7 @@ limitations under the License.
==============================================================================*/
// This transformation pass applies quantization propagation on TFLite dialect.
#include <cmath>
#include <iterator>
#include <string>
@ -21,10 +22,13 @@ limitations under the License.
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SmallVector.h"
#include "llvm/ADT/StringRef.h"
#include "llvm/Support/Casting.h"
#include "llvm/Support/CommandLine.h"
#include "mlir/Dialect/Quant/QuantOps.h" // from @llvm-project
#include "mlir/IR/Function.h" // from @llvm-project
#include "mlir/IR/MLIRContext.h" // from @llvm-project
#include "mlir/IR/Operation.h" // from @llvm-project
#include "mlir/IR/PatternMatch.h" // from @llvm-project
#include "mlir/IR/Value.h" // from @llvm-project
#include "mlir/Pass/Pass.h" // from @llvm-project
#include "mlir/Transforms/GreedyPatternRewriteDriver.h" // from @llvm-project
@ -305,6 +309,52 @@ bool PrepareQuantizePass::ContainsQuantizeOps(FuncOp func) {
using PrepareQuantStats =
quant::ConvertStatsToQDQs<quant::QuantizeCastOp, quant::DequantizeCastOp>;
// Calculates the minimum power of two that is not less than the value.
double power_of_two_bound(double value) {
return std::pow(2, std::ceil(std::log2(value)));
}
// Quantize recurrent input of LSTM with 16 bits.
template <typename SourceOp, typename Q, typename DQ>
struct ConvertLstmStatsToQDQs : public OpRewritePattern<SourceOp> {
public:
explicit ConvertLstmStatsToQDQs(MLIRContext* context)
: OpRewritePattern<SourceOp>(context, /*benefit=*/2) {}
LogicalResult matchAndRewrite(SourceOp op,
PatternRewriter& rewriter) const override {
quant::StatisticsOp stats_op = llvm::dyn_cast_or_null<quant::StatisticsOp>(
op.input_cell_state().getDefiningOp());
// Recurrent input is be used within an LSTM, and thus should have one use.
if (!stats_op || !stats_op.getResult().hasOneUse()) {
return failure();
}
auto stats = stats_op.layerStats().dyn_cast<DenseFPElementsAttr>();
if (!stats) {
return failure();
}
double max = std::max(
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({0}))),
std::abs(FloatAttr::getValueAsDouble(stats.getValue<APFloat>({1}))));
double bound = power_of_two_bound(max);
Type expressed = stats_op.getType().cast<ShapedType>().getElementType();
// maximum value is adjusted to get a scale of power_of_two(max)/32768.
quant::QuantizedType quant_type = quant::fakeQuantAttrsToType(
stats_op.getLoc(), 16, -bound, bound * 32767.0 / 32768.0,
/*narrow_range*/ false, expressed, /*is_signed*/ true);
rewriter.setInsertionPointAfter(stats_op);
Type result_type = quant_type.castFromExpressedType(stats_op.getType());
auto q = rewriter.create<Q>(stats_op.getLoc(), result_type, stats_op.arg());
rewriter.replaceOpWithNewOp<DQ>(stats_op, stats_op.getType(), q);
return success();
}
};
using PrepareLstmQuantStats =
ConvertLstmStatsToQDQs<TFL::UnidirectionalSequenceLSTMOp,
quant::QuantizeCastOp, quant::DequantizeCastOp>;
void PrepareQuantizePass::runOnFunction() {
FuncOp func = getFunction();
MLIRContext* ctx = func.getContext();
@ -326,7 +376,14 @@ void PrepareQuantizePass::runOnFunction() {
OwningRewritePatternList patterns;
bool is_signed = quant_specs_.IsSignedInferenceType();
int bit_width = quant_specs_.GetQuantizationTypeWidth();
bool enforce_fixed_output_range = ContainsQuantizeOps(func);
bool quantization_aware_training_mode = ContainsQuantizeOps(func);
// Enforce fixed output range for post-training quantization and
// when the model has quantization emulation ops, unless it was disabled
// explicitly by the flag.
bool enforced_output_range =
(quant_specs_.post_training_quantization ||
quantization_aware_training_mode) &&
!quant_specs_.disable_enforced_fixed_output_range;
if (is_signed) {
patterns.insert<quant::ConvertUnsignedToSigned<quant::QuantizeCastOp>>(ctx);
// Convert quant stats to int8 quantization parameters.
@ -337,6 +394,7 @@ void PrepareQuantizePass::runOnFunction() {
// Currently, only activation stats are imported, so narrow_range = false.
patterns.insert<PrepareQuantStats>(bit_width, false, false, ctx);
}
patterns.insert<PrepareLstmQuantStats>(ctx);
applyPatternsAndFoldGreedily(func, std::move(patterns));
SanityCheckAndAdjustment(func);
@ -345,8 +403,7 @@ void PrepareQuantizePass::runOnFunction() {
// values (tensors).
ApplyQuantizationParamsPropagation(
func, is_signed, disable_per_channel || quant_specs_.disable_per_channel,
GetOpQuantSpec,
enforce_fixed_output_range || quant_specs_.post_training_quantization);
GetOpQuantSpec, enforced_output_range);
ConvertMlirQuantOpsToTFLQuantOps(func);
}

View File

@ -51,7 +51,7 @@ static ConfigProto::Experimental::MlirBridgeRollout GetUserRequest(
}
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
absl::optional<ConfigProto> config_proto) {
const tensorflow::Graph& graph, absl::optional<ConfigProto> config_proto) {
switch (GetUserRequest(config_proto)) {
case ConfigProto::Experimental::MLIR_BRIDGE_ROLLOUT_ENABLED:
return MlirBridgeRolloutPolicy::kEnabledByUser;

View File

@ -17,6 +17,7 @@ limitations under the License.
#define THIRD_PARTY_TENSORFLOW_COMPILER_MLIR_MLIR_BRIDGE_ROLLOUT_POLICY_H_
#include "absl/types/optional.h"
#include "tensorflow/core/graph/graph.h"
#include "tensorflow/core/protobuf/config.pb.h"
namespace tensorflow {
@ -46,6 +47,7 @@ enum class MlirBridgeRolloutPolicy {
// The config_proto param is a required input for all TF1 graphs but it is
// redundant for TF2 graphs.
MlirBridgeRolloutPolicy GetMlirBridgeRolloutPolicy(
const tensorflow::Graph& graph,
absl::optional<tensorflow::ConfigProto> config_proto);
} // namespace tensorflow

View File

@ -15,6 +15,7 @@ limitations under the License.
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include <memory>
#include <string>
#include "absl/container/flat_hash_set.h"
@ -32,10 +33,19 @@ limitations under the License.
#include "tensorflow/compiler/mlir/tensorflow/utils/device_util.h"
#include "tensorflow/compiler/mlir/tensorflow/utils/dump_mlir_util.h"
#include "tensorflow/core/common_runtime/graph_constructor.h"
#include "tensorflow/core/lib/monitoring/counter.h"
#include "tensorflow/core/platform/status.h"
#include "tensorflow/core/public/session_options.h"
namespace tensorflow {
auto* shadow_run_success = monitoring::Counter<0>::New(
"/tensorflow/mlir/shadow_run_success", "Success count of MLIR shadow runs");
auto* shadow_run_failure = monitoring::Counter<2>::New(
"/tensorflow/mlir/shadow_run_failure", "Failure count of MLIR shadow runs",
"kind", "name");
static inline absl::string_view StringRefToView(llvm::StringRef ref) {
return {ref.data(), ref.size()};
}
@ -109,7 +119,7 @@ Status MlirFunctionOptimizationPass::Run(
// Skip conversion from Graph to MLIR if none of the passes are enabled.
const bool is_enabled =
llvm::any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
return pass_registration.pass->IsEnabled(config_proto);
return pass_registration.pass->IsEnabled(config_proto, **graph);
});
if (!is_enabled) {
@ -123,6 +133,17 @@ Status MlirFunctionOptimizationPass::Run(
<< "(registered " << registry_->passes().size()
<< " passes)";
// For scenarios when the new bridge is enabled by analysis we need to make
// sure that MLIR transformations are executed in a shadow mode.
// In this case, no changes should be done to the original `graph`
// and no failures propagated to the user.
bool enabled_by_analysis =
mlir_rollout_policy_(**graph, config_proto) ==
MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis;
if (enabled_by_analysis) {
LOG_FIRST_N(INFO, 1) << "Shadow run of MLIR enabled after graph analysis";
}
GraphDebugInfo debug_info;
mlir::MLIRContext context;
RegisterDialects(context.getDialectRegistry());
@ -130,10 +151,21 @@ Status MlirFunctionOptimizationPass::Run(
import_config.graph_as_function = true;
import_config.control_outputs = *control_ret_node_names;
import_config.upgrade_legacy = true;
TF_ASSIGN_OR_RETURN(auto module_ref,
ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context));
auto module_ref_status = ConvertGraphToMlir(**graph, debug_info, *flib_def,
import_config, &context);
if (!module_ref_status.ok()) {
if (enabled_by_analysis) {
shadow_run_failure->GetCell("graph_to_mlir", "")->IncrementBy(1);
// Do not fail, let the old bridge to run on the original `graph`.
return Status::OK();
}
return module_ref_status.status();
}
auto module_ref = std::move(module_ref_status.ValueOrDie());
AddDevicesToOp(*module_ref, &device_set);
for (auto& pass_registration : registry_->passes()) {
@ -144,7 +176,17 @@ Status MlirFunctionOptimizationPass::Run(
DumpModule(*module_ref, llvm::formatv("mlir_{0}_before_", name));
}
TF_RETURN_IF_ERROR(pass_registration.pass->Run(config_proto, *module_ref));
auto pass_status =
pass_registration.pass->Run(config_proto, *module_ref, **graph);
if (!pass_status.ok()) {
if (enabled_by_analysis) {
shadow_run_failure->GetCell("pass", name.str())->IncrementBy(1);
// Do not fail, let the old bridge to run on the original `graph`.
return Status::OK();
}
return pass_status;
}
if (VLOG_IS_ON(1)) {
DumpModule(*module_ref, llvm::formatv("mlir_{0}_after_", name));
@ -153,6 +195,25 @@ Status MlirFunctionOptimizationPass::Run(
GraphExportConfig export_config;
absl::flat_hash_set<Node*> control_ret_nodes;
// In case MLIR is enabled by analysis, verify that MLIR could be converted
// back to TF graph. Original `graph` must stay the same.
if (enabled_by_analysis) {
auto empty_graph = std::make_unique<Graph>(OpRegistry::Global());
FunctionLibraryDefinition empty_flib = empty_graph->flib_def();
auto mlir_to_graph_status =
ConvertMlirToGraph(*module_ref, export_config, &empty_graph,
&empty_flib, &control_ret_nodes);
if (mlir_to_graph_status.ok()) {
shadow_run_success->GetCell()->IncrementBy(1);
} else {
shadow_run_failure->GetCell("mlir_to_graph", "")->IncrementBy(1);
}
return Status::OK();
}
TF_RETURN_WITH_CONTEXT_IF_ERROR(
ConvertMlirToGraph(*module_ref, export_config, graph, flib_def,
&control_ret_nodes),
@ -183,7 +244,7 @@ Status MlirV1CompatGraphOptimizationPass::Run(
const bool is_enabled =
absl::c_any_of(registry_->passes(), [&](auto& pass_registration) -> bool {
return pass_registration.pass->IsEnabled(
options.session_options->config);
options.session_options->config, **options.graph);
});
if (!is_enabled) {

View File

@ -16,6 +16,9 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
#define TENSORFLOW_COMPILER_MLIR_MLIR_GRAPH_OPTIMIZATION_PASS_H_
#include <functional>
#include "tensorflow/compiler/mlir/mlir_bridge_rollout_policy.h"
#include "mlir/IR/Module.h" // from @llvm-project
#include "tensorflow/core/common_runtime/function_optimization_registry.h"
#include "tensorflow/core/common_runtime/optimization_registry.h"
@ -34,10 +37,11 @@ class MlirOptimizationPass {
public:
virtual ~MlirOptimizationPass() = default;
virtual llvm::StringRef name() const = 0;
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
virtual bool IsEnabled(const ConfigProto& config_proto,
const Graph& graph) const = 0;
virtual Status Run(const ConfigProto& config_proto,
mlir::ModuleOp module) = 0;
virtual Status Run(const ConfigProto& config_proto, mlir::ModuleOp module,
const Graph& graph) = 0;
};
class MlirOptimizationPassRegistry {
@ -59,10 +63,14 @@ class MlirOptimizationPassRegistry {
// Returns the global registry of MLIR optimization passes.
static MlirOptimizationPassRegistry& Global();
// Register optimization `pass` with the given `priority`.
void Add(int priority, std::unique_ptr<MlirOptimizationPass> pass) {
passes_.insert({priority, std::move(pass)});
}
// Free the memory allocated for all passes.
void ClearPasses() { passes_.clear(); }
const Passes& passes() const { return passes_; }
private:
@ -75,8 +83,11 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
public:
explicit MlirFunctionOptimizationPass(
const MlirOptimizationPassRegistry* registry =
&MlirOptimizationPassRegistry::Global())
: registry_(registry) {}
&MlirOptimizationPassRegistry::Global(),
std::function<MlirBridgeRolloutPolicy(const Graph& graph,
absl::optional<ConfigProto>)>
mlir_rollout_policy = GetMlirBridgeRolloutPolicy)
: registry_(registry), mlir_rollout_policy_(mlir_rollout_policy) {}
Status Run(const DeviceSet& device_set, const ConfigProto& config_proto,
std::unique_ptr<Graph>* graph, FunctionLibraryDefinition* flib_def,
@ -85,6 +96,9 @@ class MlirFunctionOptimizationPass : public FunctionOptimizationPass {
private:
const MlirOptimizationPassRegistry* registry_;
std::function<MlirBridgeRolloutPolicy(
const tensorflow::Graph& graph, absl::optional<tensorflow::ConfigProto>)>
mlir_rollout_policy_;
};
// -------------------------------------------------------------------------- //
@ -100,7 +114,8 @@ class MlirV1CompatOptimizationPass {
public:
virtual ~MlirV1CompatOptimizationPass() = default;
virtual llvm::StringRef name() const = 0;
virtual bool IsEnabled(const ConfigProto& config_proto) const = 0;
virtual bool IsEnabled(const ConfigProto& config_proto,
const Graph& graph) const = 0;
virtual Status Run(const GraphOptimizationPassOptions& options,
mlir::ModuleOp module) = 0;

View File

@ -0,0 +1,121 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/mlir/mlir_graph_optimization_pass.h"
#include <memory>
#include "tensorflow/core/platform/test.h"
namespace tensorflow {
using ::testing::_;
using ::testing::NiceMock;
using ::testing::Return;
using ::testing::Test;
class MockMlirOptimizationPass : public MlirOptimizationPass {
public:
MOCK_METHOD(llvm::StringRef, name, (), (const, override));
MOCK_METHOD(bool, IsEnabled,
(const ConfigProto& config_proto, const Graph& graph),
(const, override));
MOCK_METHOD(Status, Run,
(const ConfigProto& config_proto, mlir::ModuleOp module,
const Graph& graph),
(override));
};
class MlirGraphOptimizationPassTest : public Test {
public:
void Init(MlirBridgeRolloutPolicy rollout_policy, Status pass_run_result) {
graph_ = std::make_unique<Graph>(OpRegistry::Global());
function_optimization_pass_ = MlirFunctionOptimizationPass(
&MlirOptimizationPassRegistry::Global(),
[rollout_policy](const Graph& graph, absl::optional<ConfigProto>) {
return rollout_policy;
});
auto optimization_pass =
std::make_unique<NiceMock<MockMlirOptimizationPass>>();
EXPECT_CALL(*optimization_pass, IsEnabled(_, _))
.WillRepeatedly(Return(true));
EXPECT_CALL(*optimization_pass, Run(_, _, _))
.WillOnce(Return(pass_run_result));
MlirOptimizationPassRegistry::Global().Add(0, std::move(optimization_pass));
flib_.reset(new FunctionLibraryDefinition(graph_->flib_def()));
}
void TearDown() override {
MlirOptimizationPassRegistry::Global().ClearPasses();
}
ConfigProto config_proto_;
MlirFunctionOptimizationPass function_optimization_pass_;
DeviceSet device_set_;
std::unique_ptr<Graph> graph_;
std::unique_ptr<FunctionLibraryDefinition> flib_;
std::vector<std::string> control_ret_node_names_;
bool control_rets_updated_{false};
};
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsNoShadow) {
Init(MlirBridgeRolloutPolicy::kEnabledByUser,
Status(error::Code::ABORTED, "aborted"));
GraphDef original_graph_def;
graph_->ToGraphDef(&original_graph_def);
EXPECT_EQ(function_optimization_pass_.Run(
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status(error::Code::ABORTED, "aborted"));
// Proto matchers might be unavailable.
#if defined(PLATFORM_GOOGLE)
GraphDef resulted_graph_def;
graph_->ToGraphDef(&resulted_graph_def);
EXPECT_THAT(resulted_graph_def,
::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def)));
#endif
}
TEST_F(MlirGraphOptimizationPassTest, OptimizationPassFailsShadow) {
Init(MlirBridgeRolloutPolicy::kEnabledAfterGraphAnalysis,
Status(error::Code::ABORTED, "aborted"));
GraphDef original_graph_def;
graph_->ToGraphDef(&original_graph_def);
EXPECT_EQ(function_optimization_pass_.Run(
device_set_, config_proto_, &graph_, flib_.get(),
&control_ret_node_names_, &control_rets_updated_),
Status::OK());
// Proto matchers might be unavailable.
#if defined(PLATFORM_GOOGLE)
GraphDef resulted_graph_def;
graph_->ToGraphDef(&resulted_graph_def);
EXPECT_THAT(resulted_graph_def,
::testing::proto::IgnoringRepeatedFieldOrdering(
::testing::EquivToProto(original_graph_def)));
#endif
}
} // namespace tensorflow

View File

@ -604,6 +604,29 @@ If `condition` evaluates to false, print the list of tensors in `data`.
let hasCanonicalizer = 1;
}
def TF_AssignOp : TF_Op<"Assign", [NoSideEffect]> {
let summary = "Update 'ref' by assigning 'value' to it.";
let description = [{
This operation outputs "ref" after the assignment is done.
This makes it easier to chain operations that need to use the reset value.
}];
let arguments = (ins
TF_Tensor:$ref,
TF_Tensor:$value,
DefaultValuedAttr<BoolAttr, "true">:$validate_shape,
DefaultValuedAttr<BoolAttr, "true">:$use_locking
);
let results = (outs
TF_Tensor:$output_ref
);
TF_DerivedOperandTypeAttr T = TF_DerivedOperandTypeAttr<0>;
}
def TF_AssignAddVariableOp : TF_Op<"AssignAddVariableOp", []> {
let summary = "Adds a value to the current value of a variable.";
@ -12629,6 +12652,8 @@ retained with length 1.
OpBuilderDAG<(ins "Value":$input, "Value":$reduction_indices,
"BoolAttr":$keep_dims)>
];
let hasFolder = 1;
}
def TF_SymbolicGradientOp : TF_Op<"SymbolicGradient", [NoSideEffect]> {

View File

@ -684,6 +684,7 @@ body: A function that takes a list of tensors and returns another
FlatSymbolRefAttr:$cond,
FlatSymbolRefAttr:$body,
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
// Used to map StatelessWhile and While op defined in TensorFlow to a common
@ -696,12 +697,10 @@ body: A function that takes a list of tensors and returns another
);
TF_DerivedOperandTypeListAttr T = TF_DerivedOperandTypeListAttr<0>;
TF_DerivedResultShapeListAttr output_shapes = TF_DerivedResultShapeListAttr<0>;
let verifier = [{
return Verify(*this);
}];
let hasCanonicalizer = 1;
let extraClassDeclaration = [{
// Get the condition function.
@ -755,8 +754,9 @@ def TF_WhileRegionOp : TF_Op<"WhileRegion",
// Used to map StatelessWhile and While op defined in TensorFlow to a common
// op.
DefaultValuedAttr<BoolAttr, "false">:$is_stateless,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations
DefaultValuedAttr<TF_ShapeAttrArray, "{}">:$output_shapes,
DefaultValuedAttr<I64Attr, "10">:$parallel_iterations,
DefaultValuedAttr<BoolAttr, "false">:$is_stateless
);
let results = (outs Variadic<AnyTensor>:$output);

View File

@ -1539,6 +1539,21 @@ void SumOp::build(OpBuilder &builder, OperationState &result, Value input,
build(builder, result, out_ty, input, reduction_indices, keep_dims);
}
// TODO: Templatize this fold for all reduction ops.
OpFoldResult SumOp::fold(ArrayRef<Attribute> operands) {
auto input_ty = input().getType().template dyn_cast<RankedTensorType>();
if (!input_ty) return {};
auto result_ty = getType().template dyn_cast<RankedTensorType>();
if (!result_ty) return {};
// Bypass this op if the result has the same shape and type. This can happen
// if the input tensor has size 0 or size 1.
if (!keep_dims() && input_ty == result_ty) {
return input();
}
return {};
}
//===----------------------------------------------------------------------===//
// StridedSliceOp
//===----------------------------------------------------------------------===//
@ -2590,14 +2605,6 @@ static LogicalResult Verify(WhileOp op) {
return success();
}
//===----------------------------------------------------------------------===//
// WhileOp canonicalization.
//===----------------------------------------------------------------------===//
void WhileOp::getCanonicalizationPatterns(OwningRewritePatternList &results,
MLIRContext *context) {
results.insert<DropAttributes<WhileOp>>(context);
}
//===----------------------------------------------------------------------===//
// WhileRegionOp
//===----------------------------------------------------------------------===//

View File

@ -82,25 +82,28 @@ static LogicalResult Verify(SessionInitializerOp session_initializer) {
mlir::SymbolTable symbol_table(
session_initializer.getParentOfType<ModuleOp>());
auto init_func_op =
symbol_table.lookup<mlir::FuncOp>(session_initializer.initializer());
if (!init_func_op)
return session_initializer.emitOpError()
<< "the initializer function does not exist";
for (auto sym_ref : session_initializer.initializers()) {
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
sym_ref.cast<FlatSymbolRefAttr>().getValue());
if (!init_func_op.getType().getResults().empty())
return session_initializer.emitOpError()
<< "the initializer function should have no output";
if (!init_func_op)
return session_initializer.emitOpError()
<< "the initializer function does not exist";
auto exported_names = GetExportedNames(init_func_op);
if (!init_func_op.getType().getResults().empty())
return session_initializer.emitOpError()
<< "the initializer function should have no output";
if (exported_names.empty())
return session_initializer.emitOpError()
<< "the initializer function should be exported";
auto exported_names = GetExportedNames(init_func_op);
if (exported_names.size() != 1)
return session_initializer.emitOpError()
<< "the initializer function should have only one exported names";
if (exported_names.empty())
return session_initializer.emitOpError()
<< "the initializer function should be exported";
if (exported_names.size() != 1)
return session_initializer.emitOpError()
<< "the initializer function should have only one exported names";
}
return success();
}
@ -291,7 +294,11 @@ static LogicalResult VerifySavedModelModule(
auto is_init = [&session_initializers](mlir::FuncOp func) {
if (session_initializers.empty()) return false;
return (*session_initializers.begin()).initializer() == func.getName();
auto init_syms = (*session_initializers.begin()).initializers();
return std::any_of(
init_syms.begin(), init_syms.end(), [&](Attribute sym_ref) {
return sym_ref.cast<FlatSymbolRefAttr>().getValue() == func.getName();
});
};
SymbolTable symbol_table(module);
@ -450,22 +457,36 @@ class OptimizeSessionInitializerPattern
LogicalResult matchAndRewrite(SessionInitializerOp op,
PatternRewriter &rewriter) const override {
SymbolTable symbol_table(op.getParentOfType<ModuleOp>());
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(op.initializer());
// The init function can only be referenced from the SessionInitializerOp.
// And there is at most one SessionInitializerOp in the module. So if both
// ops have no other uses or have one NoOp only, they can be simply erased.
auto &operations = init_func_op.front().getOperations();
if ((operations.size() == 1 && operations.front().isKnownTerminator()) ||
(operations.size() == 2 &&
dyn_cast<mlir::TF::NoOp>(operations.front()) &&
operations.back().isKnownTerminator())) {
rewriter.eraseOp(init_func_op);
rewriter.eraseOp(op);
return success();
SmallVector<FuncOp, 2> to_remove;
SmallVector<mlir::Attribute, 2> to_keep;
for (auto sym_ref : op.initializers()) {
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
sym_ref.cast<FlatSymbolRefAttr>().getValue());
// The init function can only be referenced from the SessionInitializerOp.
// And there is at most one SessionInitializerOp in the module. So if both
// ops have no other uses or have one NoOp only, they can be simply
// erased.
auto &operations = init_func_op.front().getOperations();
if ((operations.size() == 1 && operations.front().isKnownTerminator()) ||
(operations.size() == 2 &&
dyn_cast<mlir::TF::NoOp>(operations.front()) &&
operations.back().isKnownTerminator())) {
to_remove.push_back(init_func_op);
} else {
to_keep.push_back(sym_ref);
}
}
return failure();
for (auto func_op : to_remove) rewriter.eraseOp(func_op);
if (to_keep.empty())
rewriter.eraseOp(op);
else
op.setAttr("initializers", rewriter.getArrayAttr(to_keep));
return success();
}
};
@ -474,15 +495,22 @@ void SessionInitializerOp::getCanonicalizationPatterns(
results.insert<OptimizeSessionInitializerPattern>(context);
}
llvm::Optional<StringRef> GetSessionInitializerExportedName(ModuleOp op) {
SmallVector<StringRef, 2> GetSessionInitializerExportedName(ModuleOp op) {
auto session_initializer_op = GetSessionInitializerOp(op);
if (!session_initializer_op) return llvm::None;
if (!session_initializer_op) return {};
SymbolTable symbol_table(op);
auto init_func_op =
symbol_table.lookup<mlir::FuncOp>(session_initializer_op.initializer());
auto exported_names = GetExportedNames(init_func_op);
return exported_names[0];
SmallVector<StringRef, 2> results;
for (auto sym_ref : session_initializer_op.initializers()) {
auto init_func_op = symbol_table.lookup<mlir::FuncOp>(
sym_ref.cast<FlatSymbolRefAttr>().getValue());
auto exported_names = GetExportedNames(init_func_op);
assert(exported_names.size() == 1);
results.push_back(exported_names[0]);
}
return results;
}
} // namespace tf_saved_model

View File

@ -81,7 +81,7 @@ Type GetBoundInputArgTypeFor(mlir::Operation *op);
SessionInitializerOp GetSessionInitializerOp(mlir::ModuleOp op);
// Returns the exported name for the session initializer function.
llvm::Optional<StringRef> GetSessionInitializerExportedName(mlir::ModuleOp op);
SmallVector<StringRef, 2> GetSessionInitializerExportedName(mlir::ModuleOp op);
} // namespace tf_saved_model
} // namespace mlir

View File

@ -132,13 +132,13 @@ def TfSavedModel_GlobalTensorOp : TfSavedModel_Op<"global_tensor"> {
def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
let summary = "Initializes TensorFlow session state.";
let description = [{
The session initializer op marks a function that must be called by an
external agent exactly once to initialize TensorFlow session state, and this
must happen before any other exported functions are called. There must be no
more than one session initializer in a saved model.
The session initializer op marks one or more functions that must be called
by an external agent exactly once to initialize TensorFlow session state,
and this must happen before any other exported functions are called. There
must be no more than one session initializer op in a saved model.
The `initializer` represents the initialization function. The function have
no output and this function should be only called once.
The `initializers` represents the initialization functions. The function
have no output and this function should be only called once.
This is used, for example, to initialize hash tables stored in resources and
accessed by resource name (rather than as resource handles or bound inputs
@ -146,7 +146,7 @@ def TfSavedModel_SessionInitializerOp: TfSavedModel_Op<"session_initializer"> {
}];
let arguments = (ins
FlatSymbolRefAttr:$initializer
SymbolRefArrayAttr:$initializers
);
@ -160,7 +160,7 @@ def TfSavedModel_AssetOp: TfSavedModel_Op<"asset", [Symbol]> {
let description = [{
Represents an asset in the saved model that points to an external file. It
is a scalar string tensor and it is passed as an argument to the session
initializer function.
initializer functions.
The `sym_name` represents the symbol table name used for internal IR
references.

View File

@ -1259,24 +1259,6 @@ func @testIfDropOutputShapes(tensor<i1>, tensor<2xf32>) -> tensor<2xf32> {
return %1 : tensor<2xf32>
}
// Check that output_shapes attribute is removed for tf.Whileß
func @testWhileCond(tensor<*xf32>) -> (tensor<i1>)
func @testWhileBody(tensor<*xf32>) -> (tensor<*xf32>)
// CHECK-LABEL: func @testWhileDropOutputShapes
func @testWhileDropOutputShapes(tensor<*xf32>) -> (tensor<*xf32>) {
^bb0(%arg0: tensor<*xf32>):
// CHECK: "tf.While"
// CHECK-NOT: output_shapes
%1 = "tf.While"(%arg0) {
cond = @testWhileCond,
body = @testWhileBody,
is_stateless = false,
output_shapes = [#tf.shape<>]
} : (tensor<*xf32>) -> (tensor<*xf32>)
return %1 : tensor<*xf32>
}
// CHECK-LABEL: testNMSV3ToNMSV4
func @testNMSV3ToNMSV4(%arg0: tensor<3x4xf32>, %arg1: tensor<3xf32>, %arg2: tensor<f32>, %arg3: tensor<f32>) -> tensor<2xi32> {
%max_size = constant dense<2> : tensor<i32>
@ -1291,3 +1273,10 @@ func @testFusedBatchNormToBatchNormV3(%arg0: tensor<8x8x8x8xf32>, %arg1: tensor<
%0:5 = "tf.FusedBatchNorm"(%arg0, %arg1, %arg2, %arg3, %arg4): (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>) -> (tensor<8x8x8x8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32>, tensor<8xf32> )
return %0#0 : tensor<8x8x8x8xf32>
}
// CHECK-LABEL: func @testSumFoldBypass
func @testSumFoldBypass(%arg0: tensor<4x?xf16>, %arg1: tensor<*xi64>) -> tensor<4x?xf16> {
// CHECK: return %arg0
%0 = "tf.Sum"(%arg0, %arg1) { keep_dims = false }: (tensor<4x?xf16>, tensor<*xi64>) -> tensor<4x?xf16>
return %0 : tensor<4x?xf16>
}

View File

@ -24,9 +24,8 @@ func @single_cluster(%arg0: tensor<?xi32>) -> tensor<?xi32> {
return %0 : tensor<?xi32>
}
// CHECK: func @[[CLUSTER]]
// CHECK: func private @[[CLUSTER]]
// CHECK-SAME: (%[[CLUSTER_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
// CHECK-SAME: sym_visibility = "private"
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_ARG_0]])
// CHECK: return %[[B_OUTPUT]]
@ -67,12 +66,12 @@ func @multiple_clusters(%arg0: tensor<?xi32>) -> tensor<?xi32> {
return %0 : tensor<?xi32>
}
// CHECK: func @[[CLUSTER_0]]
// CHECK: func private @[[CLUSTER_0]]
// CHECK-SAME: (%[[CLUSTER_0_ARG_0:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[B_OUTPUT:[0-9]*]] = "tf.B"(%[[CLUSTER_0_ARG_0]])
// CHECK: return %[[B_OUTPUT]]
// CHECK: func @[[CLUSTER_1]]
// CHECK: func private @[[CLUSTER_1]]
// CHECK-SAME: (%[[CLUSTER_1_ARG_0:[a-z0-9]*]]: tensor<?xi32>, %[[CLUSTER_1_ARG_1:[a-z0-9]*]]: tensor<?xi32>) -> tensor<?xi32>
// CHECK: %[[E_OUTPUT:[0-9]*]] = "tf.E"(%[[CLUSTER_1_ARG_0]])
// CHECK: %[[F_OUTPUT:[0-9]*]] = "tf.F"(%[[CLUSTER_1_ARG_1]], %[[E_OUTPUT]])
@ -98,7 +97,7 @@ func @cluster_operands(%arg0: tensor<?xi32>) -> tensor<?xi32> {
return %0 : tensor<?xi32>
}
// CHECK: func @[[CLUSTER]]
// CHECK: func private @[[CLUSTER]]
// CHECK-SAME: () -> tensor<?xi32>
// CHECK: %[[A_OUTPUT:[0-9]*]] = "tf.A"()
// CHECK: return %[[A_OUTPUT]]

View File

@ -47,11 +47,11 @@ func @func2(%arg0 : tensor<i1>) -> tensor<i1> {
// CHECK: module
// CHECK-SAME: @_tpu_v1_compat_outlined
// CHECK-LABEL: func @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> tensor<i1>
// CHECK-LABEL: func nested @_tpu_v1_compat_outlined_func0(%arg0: tensor<i1>) -> tensor<i1>
// CHECK-NEXT: tf.TPUReplicateMetadata
// CHECK-NEXT: tf.opA
// CHECK-LABEL: func @_tpu_v1_compat_outlined_func1(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<i1>, tensor<i32>)
// CHECK-LABEL: func nested @_tpu_v1_compat_outlined_func1(%arg0: tensor<i1>, %arg1: tensor<f32>) -> (tensor<i1>, tensor<i32>)
// CHECK-NEXT: tf.TPUReplicateMetadata
// CHECK-NEXT: tf.opA
// CHECK-NEXT: tf.opA

View File

@ -27,14 +27,14 @@ func @foo() {
// In the newly cloned function, check that we have a _tf.If operation and capture the then and else branch.
// CHECK: func @[[FUNCTIONALIZE_FUNC]]
// CHECK: func private @[[FUNCTIONALIZE_FUNC]]
// CHECK: "tf.If"
// CHECK-SAME: else_branch = @[[ELSE_FUNC:[A-Za-z0-9_]*]]
// CHECK-SAME: then_branch = @[[THEN_FUNC:[A-Za-z0-9_]*]]
// We expect the _tf.Add in the else func and the _tf.Mul in the then func
// CHECK: func @[[ELSE_FUNC]]
// CHECK: func private @[[ELSE_FUNC]]
// CHECK: "tf.Add"
// CHECK: func @[[THEN_FUNC]]
// CHECK: func private @[[THEN_FUNC]]
// CHECK: "tf.Mul"

View File

@ -40,7 +40,7 @@ library {
}
}
# Drop the control dependency on arg for the node "test"
# CHECK-LABEL: func @foo
# CHECK-LABEL: func private @foo
# CHECK: tf_executor.island wraps "tf.Const"()
node_def {
name: "test"

View File

@ -80,6 +80,6 @@ versions {
# CHECK-SAME: f = @[[FUNCTION:[a-zA-Z0-9_]*]]
# Verify that callee has the unit attribute tf._input_shapes.
# CHECK: func @[[FUNCTION]]
# CHECK: func private @[[FUNCTION]]
# CHECK: attributes
# CHECK-SAME: tf._input_shapes{{[,}]}}

View File

@ -90,6 +90,6 @@ library {
# CHECK: tf.HashTableV2
# CHECK-SAME: shared_name = "hash_table_node"
# CHECK: func @create_resource
# CHECK: func private @create_resource
# CHECK: tf.HashTableV2
# CHECK-SAME: shared_name = "hash_table_node@create_resource"

View File

@ -49,5 +49,5 @@ library {
}
}
# CHECK-DAG: func @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}
# CHECK-DAG: func private @custom_relu{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.relu, {}>}
# CHECK-DAG: func private @custom_embedding_matmul{{[0-9]*}}(){{.+}}tf._implements = #tf.func<@tensorflow.embedding_matmul, {key1 = 2 : i64, key2 = false}>}

View File

@ -13,7 +13,7 @@
# CHECK: %[[ISLAND_2:.*]], %[[ISLAND_2_control:.*]] = tf_executor.island wraps "tf.StatefulPartitionedCall"
# CHECK-SAME: f = @[[FUNC:[a-z0-9]*]]
# CHECK: tf_executor.fetch %[[ISLAND_1]], %[[ISLAND_2]] : tensor<*xf32>, tensor<*xf32>
# CHECK: func @[[FUNC]](%arg0: tensor<*xf32> {tf._user_specified_name = "inputs"}, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32>
# CHECK: func private @[[FUNC]](%arg0: tensor<*xf32> {tf._user_specified_name = "inputs"}, %arg1: tensor<*x!tf.resource>) -> tensor<*xf32>
node {
name: "args_0"

View File

@ -55,4 +55,4 @@ versions {
# site (a numerical suffix may be appended).
# CHECK: "tf.LegacyCall"(%outputs) {_disable_call_shape_inference = false, device = "", f = @foo0}
# CHECK: func @foo0
# CHECK: func private @foo0

View File

@ -74,7 +74,7 @@ library {
}
# The attribute "experimental_ints_on_device" and the return type INT32
# ensure that kDeviceRetOp is used instead of kRetOp
# CHECK-LABEL: func @foo
# CHECK-LABEL: func private @foo
# CHECK: tf.experimental_ints_on_device = true
# CHECK: return %{{.*}} tensor<{{.*}}i32>
attr {

View File

@ -5,8 +5,8 @@
# Verify that the NameAttrList is properly turned into reference to functions on import
# CHECK: tf.Case
# CHECK-SAME: branches = [@[[FOO:[a-z0-9]+]], @[[BAR:[a-z0-9]+]]]
# CHECK-DAG: func @[[FOO]]()
# CHECK-DAG: func @[[BAR]]()
# CHECK-DAG: func private @[[FOO]]()
# CHECK-DAG: func private @[[BAR]]()
node {
name: "predicate"

View File

@ -3,7 +3,7 @@
# Verify that the _input_shapes attribute of the FunctionDef is respected.
# This also checks that the output type is correctly inferred based on
# that.
#CHECK: func @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
#CHECK: func private @identity_function0(%arg0: tensor<i32>) -> tensor<i32>
node {
name: "Placeholder"

View File

@ -124,5 +124,5 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo110}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @foo111}
# CHECK-LABEL: func @foo110() attributes {sym_visibility = "private"}
# CHECK-LABEL: func @foo111() attributes {sym_visibility = "private"}
# CHECK-LABEL: func private @foo110()
# CHECK-LABEL: func private @foo111()

View File

@ -91,7 +91,7 @@ library {
# CHECK-SAME: {_disable_call_shape_inference = true, device = "", f = @test_func_name0}
# CHECK: tf_executor.fetch
# CHECK: return
# CHECK: func @test_func_name0
# CHECK: func private @test_func_name0
# CHECK-SAME: tf._resource_arg_unique_id = 0
# CHECK-SAME: tf._resource_arg_unique_id = 0
# CHECK: tf_executor.graph

View File

@ -4,7 +4,7 @@
# links the function and its gradient. In MLIR a TF ops gradient function is
# added to its list of function attributes.
# CHECK: func @foo0(
# CHECK: func private @foo0(
# CHECK: tf.gradient = @foo_grad
node {

View File

@ -4,8 +4,8 @@
# functions with arg name that are the same as the graph input name
# CHECK: func @main(%arg0: tensor<{{.*}}i32>) -> tensor<{{.*}}i32>
# CHECK: func @while_body
# CHECK: func @while_cond
# CHECK: func private @while_body
# CHECK: func private @while_cond
node {
name: "input"

View File

@ -57,7 +57,7 @@ versions {
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = true, device = "", f = @foo0}
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
# CHECK-LABEL: func @foo0() attributes {sym_visibility = "private"}
# CHECK-LABEL: func private @foo0()
# CHECK: "tf.LegacyCall"() {_disable_call_shape_inference = false, device = "", f = @bar0}
# CHECK-LABEL: func @bar0() attributes {sym_visibility = "private"}
# CHECK-LABEL: func private @bar0()

View File

@ -106,5 +106,5 @@ versions {
# CHECK: func @main
# CHECK: "tf.PartitionedCall"()
# CHECK-SAME: f = @[[FUNCTION:[A-Za-z0-9_]*]]
# CHECK: func @[[FUNCTION]]() -> tensor<*xui8>
# CHECK: func private @[[FUNCTION]]() -> tensor<*xui8>
# CHECK: return {{.*}} : tensor<*xui8>

View File

@ -86,6 +86,6 @@ versions {
# CHECK-SAME: f = @[[FUNCTION_FOO:[a-zA-Z0-9_]*]]
# Find callee and verify it has the stateful attribute set.
# CHECK: func @[[FUNCTION_FOO]]
# CHECK: func private @[[FUNCTION_FOO]]
# CHECK-SAME: attributes
# CHECK-SAME: tf.signature.is_stateful

View File

@ -12,7 +12,7 @@ func @f() {
}
// CHECK: func @g()
// CHECK: func @[[NEWG]]() attributes {sym_visibility = "private"}
// CHECK: func private @[[NEWG]]()
func @g() {
return
}
@ -22,12 +22,12 @@ func @g() {
// CHECK-LABEL: func @f
// 2 copies of @g
// CHECK-DAG: func @g{{.*}}
// CHECK-DAG: func @g{{.*}}
// CHECK-DAG: func private @g{{.*}}
// 4 copies of @h
// CHECK-DAG: func @h{{.*}}
// CHECK-DAG: func @h{{.*}}
// CHECK-DAG: func @h{{.*}}
// CHECK-DAG: func @h{{.*}}
// CHECK-DAG: func private @h{{.*}}
// CHECK-DAG: func private @h{{.*}}
// CHECK-DAG: func private @h{{.*}}
func @f() {
call @g() : () -> ()
call @g() : () -> ()
@ -47,7 +47,7 @@ func @h() {
// -----
// Handle error case of infinite recursion.
// expected-error @+1 {{reached cloning limit}}
func @f() attributes {sym_visibility = "private"} {
func private @f() {
call @f() : () -> ()
call @f() : () -> ()
return

View File

@ -33,4 +33,4 @@ func @foo(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// CHECK: "tf.Identity"([[CALL_RESULT_REG]])
// Match the function name
// CHECK: func @[[FUNCTION]]
// CHECK: func private @[[FUNCTION]]

View File

@ -231,11 +231,11 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
// CHECK-DAG: [[PAD00:%.+]] = "tf.Const"() {value = dense<0> : tensor<1x2xi64>}
// CHECK-DAG: [[ZERO_I32:%.+]] = "tf.Const"() {value = dense<0> : tensor<i32>}
// CHECK-DAG: [[ZERO_I64:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK-DAG: [[ONE_I64:%.+]] = "tf.Const"() {value = dense<1> : tensor<i64>}
// CHECK-DAG: [[FULL_PADDINGS:%.+]] = "tf.ConcatV2"([[PAD00]], %arg2, [[PAD00]], [[ZERO_I64]])
// CHECK-DAG: [[PAD_DEFAULT:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
// CHECK-DAG: [[PADDED:%.+]] = "tf.PadV2"(%arg0, [[FULL_PADDINGS]], [[PAD_DEFAULT]])
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Sum"([[FULL_PADDINGS]], [[ONE_I64]])
// CHECK-DAG: [[PADDINGS:%.+]]:2 = "tf.Unpack"([[FULL_PADDINGS]]) {axis = 1 : i64}
// CHECK-DAG: [[PADDINGS_SUM:%.+]] = "tf.Add"([[PADDINGS]]#0, [[PADDINGS]]#1)
// CHECK-DAG: [[INPUT_SHAPE:%.+]] = "tf.Const"() {value = dense<[3, 5, 7, 10]> : tensor<4xi64>}
// CHECK-DAG: [[PADDED_SHAPE:%.+]] = "tf.Add"([[PADDINGS_SUM]], [[INPUT_SHAPE]])
// CHECK-DAG: [[PADDED_SHAPE_SPLITS:%.+]]:4 = "tf.Split"([[ZERO_I32]], [[PADDED_SHAPE]])
@ -256,14 +256,25 @@ func @fourdim_space_to_batch_nd(%input: tensor<3x5x7x10xf32>, %block_shape: tens
}
// Verify the result shape for the tf.PadV2 op.
// CHECK-LABEL: const_paddings_space_to_batch_nd
func @const_paddings_space_to_batch_nd(%arg0: tensor<1x8x2xf32>) -> (tensor<3x5x2xf32>) {
%0 = "tf.Const"() {value = dense<3> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = "tf.Const"() {value = dense<[[3, 4]]> : tensor<1x2xi32>} : () -> tensor<1x2xi32>
// CHECK: "tf.PadV2"
// CHECK-SAME: tensor<1x5x2xf32>
// CHECK-DAG: [[VAL0:%.+]] = "tf.Const"() {value = dense<[3, 5, 2]> : tensor<3xi64>}
// CHECK-DAG: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 5, 3, 2]> : tensor<4xi64>}
// CHECK-DAG: [[VAL2:%.+]] = "tf.Const"() {value = dense<{{\[\[}}0, 0], [3, 4], [0, 0{{\]\]}}> : tensor<3x2xi64>}
// CHECK-DAG: [[VAL3:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
// CHECK-DAG: [[VAL4:%.+]] = "tf.Const"() {value = dense<[2, 0, 1, 3]> : tensor<4xi64>}
// CHECK-DAG: [[VAL5:%.+]] = "tf.PadV2"(%arg0, [[VAL2]], [[VAL3]])
// CHECK-SAME: tensor<1x15x2xf32>
// CHECK-DAG: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
// CHECK-DAG: [[VAL7:%.+]] = "tf.Transpose"([[VAL6]], [[VAL4]])
// CHECK-DAG: [[VAL8:%.+]] = "tf.Reshape"([[VAL7]], [[VAL0]])
%2 = "tf.SpaceToBatchND"(%arg0, %0, %1) : (tensor<1x8x2xf32>, tensor<1xi32>, tensor<1x2xi32>) -> tensor<3x5x2xf32>
// CHECK: return [[VAL8]]
return %2 : tensor<3x5x2xf32>
}
@ -757,3 +768,117 @@ func @lgamma(%arg0: tensor<4xf32>) -> tensor<4xf32> {
%0 = "tf.Lgamma"(%arg0) : (tensor<4xf32>) -> tensor<4xf32>
return %0 : tensor<4xf32>
}
// CHECK-LABEL: func @imag_resize_nearest
func @imag_resize_nearest(%arg0: tensor<1x7x7x1xi32>) -> tensor<1x3x3x1xi32> {
%shape = "tf.Const"() {device = "", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: [[VAL1:%.+]] = "tf.Const"() {value = dense<[1, 3, 3, 1]>
// CHECK: [[VAL2:%.+]] = "tf.Const"() {value = dense<[1, 49, 1]>
// CHECK: [[VAL3:%.+]] = "tf.Const"() {value = dense<[0, 2, 4, 14, 16, 18, 28, 30, 32]> : tensor<9xi32>}
// CHECK: [[VAL4:%.+]] = "tf.Reshape"(%arg0, [[VAL2]])
// CHECK: [[VAL5:%.+]] = "tf.GatherV2"([[VAL4]], [[VAL3]], [[VAL0]]) {batch_dims = 0 : i64}
// CHECK: [[VAL6:%.+]] = "tf.Reshape"([[VAL5]], [[VAL1]])
// CHECK: return [[VAL6]]
%resize = "tf.ResizeNearestNeighbor"(%arg0, %shape) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x7x7x1xi32>, tensor<2xi32>) -> tensor<1x3x3x1xi32>
return %resize: tensor<1x3x3x1xi32>
}
// CHECK-LABEL: func @imag_resize_nearest_dyn_img
func @imag_resize_nearest_dyn_img(%arg0: tensor<1x?x?x1xi32>) -> tensor<1x3x3x1xi32> {
%shape = "tf.Const"() {device = "", value = dense<3> : tensor<2xi32>} : () -> tensor<2xi32>
// CHECK: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: [[VAL1:%.+]] = "tf.Const"() {value = dense<[3, 1]> : tensor<2xi32>}
// CHECK: [[VAL2:%.+]] = "tf.Const"() {value = dense<9> : tensor<1xi32>}
// CHECK: [[VAL3:%.+]] = "tf.Const"() {value = dense<3> : tensor<1xi32>}
// CHECK: [[VAL4:%.+]] = "tf.Const"() {value = dense<[1, 3]> : tensor<2xi32>}
// CHECK: [[VAL5:%.+]] = "tf.Const"() {value = dense<[0.000000e+00, 1.000000e+00, 2.000000e+00]>
// CHECK: [[VAL6:%.+]] = "tf.Const"() {value = dense<3.000000e+00> : tensor<f32>}
// CHECK: [[VAL7:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK: [[VAL8:%.+]] = "tf.Shape"(%arg0)
// CHECK: [[VAL9:%.+]] = "tf.Cast"([[VAL8]])
// CHECK: [[VAL10:%.+]]:4 = "tf.Unpack"([[VAL9]]) {axis = 0 : i64}
// CHECK: [[VAL11:%.+]] = "tf.Mul"([[VAL10]]#1, [[VAL10]]#2)
// CHECK: [[VAL12:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL7]])
// CHECK: [[VAL13:%.+]] = "tf.ExpandDims"([[VAL10]]#3, [[VAL7]])
// CHECK: [[VAL14:%.+]] = "tf.ConcatV2"([[VAL12]], [[VAL3]], [[VAL3]], [[VAL13]], [[VAL7]])
// CHECK: [[VAL15:%.+]] = "tf.Cast"([[VAL10]]#1)
// CHECK: [[VAL16:%.+]] = "tf.Div"([[VAL15]], [[VAL6]])
// CHECK: [[VAL17:%.+]] = "tf.Mul"([[VAL16]], [[VAL5]])
// CHECK: [[VAL18:%.+]] = "tf.Cast"([[VAL17]])
// CHECK: [[VAL19:%.+]] = "tf.Reshape"([[VAL18]], [[VAL1]])
// CHECK: [[VAL20:%.+]] = "tf.Mul"([[VAL19]], [[VAL10]]#2)
// CHECK: [[VAL21:%.+]] = "tf.Cast"([[VAL10]]#2)
// CHECK: [[VAL22:%.+]] = "tf.Div"([[VAL21]], [[VAL6]])
// CHECK: [[VAL23:%.+]] = "tf.Mul"([[VAL22]], [[VAL5]])
// CHECK: [[VAL24:%.+]] = "tf.Cast"([[VAL23]])
// CHECK: [[VAL25:%.+]] = "tf.Reshape"([[VAL24]], [[VAL4]])
// CHECK: [[VAL26:%.+]] = "tf.AddV2"([[VAL20]], [[VAL25]])
// CHECK: [[VAL27:%.+]] = "tf.Reshape"([[VAL26]], [[VAL2]])
// CHECK: [[VAL28:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL7]])
// CHECK: [[VAL29:%.+]] = "tf.ExpandDims"([[VAL11]], [[VAL7]])
// CHECK: [[VAL30:%.+]] = "tf.ExpandDims"([[VAL10]]#3, [[VAL7]])
// CHECK: [[VAL31:%.+]] = "tf.ConcatV2"([[VAL28]], [[VAL29]], [[VAL30]], [[VAL7]])
// CHECK: [[VAL32:%.+]] = "tf.Reshape"(%arg0, [[VAL31]])
// CHECK: [[VAL33:%.+]] = "tf.GatherV2"([[VAL32]], [[VAL27]], [[VAL0]]) {batch_dims = 0 : i64}
// CHECK: [[VAL34:%.+]] = "tf.Reshape"([[VAL33]], [[VAL14]])
// CHECK: return [[VAL34]]
%resize = "tf.ResizeNearestNeighbor"(%arg0, %shape) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x3x3x1xi32>
return %resize: tensor<1x3x3x1xi32>
}
// CHECK-LABEL: func @imag_resize_nearest_full_dyn
func @imag_resize_nearest_full_dyn(%arg0: tensor<1x?x?x1xi32>, %arg1: tensor<2xi32>) -> tensor<1x?x?x1xi32> {
// CHECK: [[VAL0:%.+]] = "tf.Const"() {value = dense<1> : tensor<i32>}
// CHECK: [[VAL1:%.+]] = "tf.Const"() {value = dense<0.000000e+00> : tensor<f32>}
// CHECK: [[VAL2:%.+]] = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>}
// CHECK: [[VAL3:%.+]] = "tf.Const"() {value = dense<1> : tensor<1xi32>}
// CHECK: [[VAL4:%.+]] = "tf.Const"() {value = dense<1> : tensor<1xi64>}
// CHECK: [[VAL5:%.+]] = "tf.Const"() {value = dense<0> : tensor<i64>}
// CHECK: [[VAL6:%.+]] = "tf.Shape"(%arg0)
// CHECK: [[VAL7:%.+]] = "tf.Cast"([[VAL6]])
// CHECK: [[VAL8:%.+]]:4 = "tf.Unpack"([[VAL7]]) {axis = 0 : i64}
// CHECK: [[VAL9:%.+]] = "tf.Mul"([[VAL8]]#1, [[VAL8]]#2)
// CHECK: [[VAL10:%.+]]:2 = "tf.Unpack"(%arg1) {axis = 0 : i64}
// CHECK: [[VAL11:%.+]] = "tf.Mul"([[VAL10]]#0, [[VAL10]]#1)
// CHECK: [[VAL12:%.+]] = "tf.ExpandDims"([[VAL8]]#0, [[VAL5]])
// CHECK: [[VAL13:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL5]])
// CHECK: [[VAL14:%.+]] = "tf.ExpandDims"([[VAL10]]#1, [[VAL5]])
// CHECK: [[VAL15:%.+]] = "tf.ExpandDims"([[VAL8]]#3, [[VAL5]])
// CHECK: [[VAL16:%.+]] = "tf.ConcatV2"([[VAL12]], [[VAL13]], [[VAL14]], [[VAL15]], [[VAL5]])
// CHECK: [[VAL17:%.+]] = "tf.Cast"([[VAL8]]#1)
// CHECK: [[VAL18:%.+]] = "tf.Cast"([[VAL10]]#0)
// CHECK: [[VAL19:%.+]] = "tf.Div"([[VAL17]], [[VAL18]])
// CHECK: [[VAL20:%.+]] = "tf.Range"([[VAL1]], [[VAL18]], [[VAL2]])
// CHECK: [[VAL21:%.+]] = "tf.Mul"([[VAL20]], [[VAL19]])
// CHECK: [[VAL22:%.+]] = "tf.Cast"([[VAL21]])
// CHECK: [[VAL23:%.+]] = "tf.ExpandDims"([[VAL10]]#0, [[VAL5]])
// CHECK: [[VAL24:%.+]] = "tf.ConcatV2"([[VAL23]], [[VAL3]], [[VAL5]])
// CHECK: [[VAL25:%.+]] = "tf.Reshape"([[VAL22]], [[VAL24]])
// CHECK: [[VAL26:%.+]] = "tf.Mul"([[VAL25]], [[VAL8]]#2)
// CHECK: [[VAL27:%.+]] = "tf.Cast"([[VAL8]]#2)
// CHECK: [[VAL28:%.+]] = "tf.Cast"([[VAL10]]#1)
// CHECK: [[VAL29:%.+]] = "tf.Div"([[VAL27]], [[VAL28]])
// CHECK: [[VAL30:%.+]] = "tf.Range"([[VAL1]], [[VAL28]], [[VAL2]])
// CHECK: [[VAL31:%.+]] = "tf.Mul"([[VAL30]], [[VAL29]])
// CHECK: [[VAL32:%.+]] = "tf.Cast"([[VAL31]])
// CHECK: [[VAL33:%.+]] = "tf.ExpandDims"([[VAL10]]#1, [[VAL5]])
// CHECK: [[VAL34:%.+]] = "tf.ConcatV2"([[VAL3]], [[VAL33]], [[VAL5]])
// CHECK: [[VAL35:%.+]] = "tf.Reshape"([[VAL32]], [[VAL34]])
// CHECK: [[VAL36:%.+]] = "tf.AddV2"([[VAL26]], [[VAL35]])
// CHECK: [[VAL37:%.+]] = "tf.Reshape"([[VAL11]], [[VAL4]])
// CHECK: [[VAL38:%.+]] = "tf.Reshape"([[VAL36]], [[VAL37]])
// CHECK: [[VAL39:%.+]] = "tf.ExpandDims"([[VAL8]]#0, [[VAL5]])
// CHECK: [[VAL40:%.+]] = "tf.ExpandDims"([[VAL9]], [[VAL5]])
// CHECK: [[VAL41:%.+]] = "tf.ExpandDims"([[VAL8]]#3, [[VAL5]])
// CHECK: [[VAL42:%.+]] = "tf.ConcatV2"([[VAL39]], [[VAL40]], [[VAL41]], [[VAL5]])
// CHECK: [[VAL43:%.+]] = "tf.Reshape"(%arg0, [[VAL42]])
// CHECK: [[VAL44:%.+]] = "tf.GatherV2"([[VAL43]], [[VAL38]], [[VAL0]]) {batch_dims = 0 : i64}
// CHECK: [[VAL45:%.+]] = "tf.Reshape"([[VAL44]], [[VAL16]])
// CHECK: return [[VAL45]]
%resize = "tf.ResizeNearestNeighbor"(%arg0, %arg1) {align_corners = false, device = "", half_pixel_centers = false} : (tensor<1x?x?x1xi32>, tensor<2xi32>) -> tensor<1x?x?x1xi32>
return %resize: tensor<1x?x?x1xi32>
}

View File

@ -1,7 +1,7 @@
// RUN: tf-opt %s -tf-mark-ops-for-outside-compilation | FILECHECK_OPTS="" FileCheck %s
// CHECK-LABEL: func @unsupported_op_no_soft_placement
func @unsupported_op_no_soft_placement() -> tensor<i32> {
// CHECK-LABEL: func @unsupported_op_missing_soft_placement_attribute
func @unsupported_op_missing_soft_placement_attribute() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
// CHECK: "tf.UnsupportedOp"
// CHECK-NOT: _xla_outside_compilation
@ -28,6 +28,24 @@ func @unsupported_op_soft_placement_false() -> tensor<i32> {
return %0 : tensor<i32>
}
// CHECK-LABEL: func @assert_op_string_operand
func @assert_op_string_operand(%arg0: tensor<!tf.string>) -> tensor<i32> {
%0 = "tf_device.cluster"() ( {
// CHECK: "tf.Assert"
// CHECK-NOT: _xla_outside_compilation
// CHECK: "tf.UnsupportedOp"
// CHECK-SAME: _xla_outside_compilation
// CHECK: "tf.Identity"
// CHECK-NOT: _xla_outside_compilation
%t = constant dense<true> : tensor<i1>
"tf.Assert"(%t, %arg0) {summarize = 3} : (tensor<i1>, tensor<!tf.string>) -> ()
%1 = "tf.UnsupportedOp"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%2 = "tf.Identity"(%1) : (tensor<i32>) -> tensor<i32>
tf_device.return %2 : tensor<i32>
}) {allow_soft_placement = true, num_cores_per_replica = 1, topology = "", device_assignment = []} : () -> tensor<i32>
return %0 : tensor<i32>
}
// CHECK-LABEL: func @unsupported_op
func @unsupported_op() -> tensor<i32> {
%0 = "tf_device.cluster"() ( {

View File

@ -2,8 +2,8 @@
func @main(%arg0: tensor<i32>, %arg1: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) {
%0:2 = tf_executor.graph {
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
%outputs_2:2, %control_3 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = false, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatefulWhile")
%outputs_4:2, %control_5 = tf_executor.island wraps "tf.While"(%arg0, %arg1) {body = @body, cond = @cond, is_stateless = true, output_shapes = [#tf.shape<>, #tf.shape<5>]} : (tensor<i32>, tensor<5xf32>) -> (tensor<i32>, tensor<5xf32>) loc("StatelessWhile")
tf_executor.fetch %outputs_2#1, %outputs_4#1 : tensor<5xf32>, tensor<5xf32>
}
return %0#0, %0#1 : tensor<5xf32>, tensor<5xf32>

View File

@ -299,7 +299,7 @@ func @main(%arg0: tensor<i32>) -> tensor<2xf32> {
%2 = "tf.PartitionedCall"(%0) {config = "", config_proto = "", executor_type = "", f = @callee} : (tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32>
return %2 : tensor<2xf32>
}
func @callee(%arg0: tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32> attributes {sym_visibility = "private"} {
func private @callee(%arg0: tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32> {
%0 = "tf.ReadVariableOp"(%arg0) : (tensor<!tf.resource<tensor<2xf32>>>) -> tensor<2xf32>
return %0 : tensor<2xf32>
}

View File

@ -1,9 +1,9 @@
// RUN: tf-opt %s -tf-region-control-flow-to-functional -split-input-file | FileCheck %s
// Simple IfRegion
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Neg"
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Abs"
func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// CHECK: "tf.If"
@ -24,9 +24,9 @@ func @testSimple(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
// -----
// Use if condition inside the regions
// CHECK: func @tf.IfRegion_else(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
// CHECK-NEXT: "tf.Select"(%arg0, %arg2, %arg3)
// CHECK: func @tf.IfRegion_then(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<i1>, %arg1: tensor<2xf32>, %arg2: tensor<2xf32>, %arg3: tensor<2xf32>) -> tensor<2xf32>
// CHECK-NEXT: "tf.Select"(%arg0, %arg1, %arg2)
func @testIfCondition(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
%0 = "tf.Add"(%arg1, %arg1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
@ -48,9 +48,9 @@ func @testIfCondition(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32>
// Constant sinking for IfRegion
// CHECK: func @tf.IfRegion_else() -> tensor<2xf32>
// CHECK: func private @tf.IfRegion_else() -> tensor<2xf32>
// CHECK-NEXT: constant dense<1.0
// CHECK: func @tf.IfRegion_then() -> tensor<2xf32>
// CHECK: func private @tf.IfRegion_then() -> tensor<2xf32>
// CHECK-NEXT: constant dense<0.0
func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
%cst_zero = constant dense<0.0> : tensor<2xf32>
@ -67,18 +67,18 @@ func @testIfConstant(%arg0: tensor<i1>) -> tensor<2xf32> {
// -----
// Nested IfRegions
// CHECK: func @tf.IfRegion1_else
// CHECK: func private @tf.IfRegion1_else
// CHECK-NEXT: "tf.Acos"
// CHECK-NEXT: "tf.Abs"
// CHECK: func @tf.IfRegion1_then
// CHECK: func private @tf.IfRegion1_then
// CHECK-NEXT: "tf.LogicalNot"
// CHECK-NEXT: "tf.Asin"
// CHECK-NEXT: "tf.If"({{.+}}) {else_branch = @tf.IfRegion_else, {{.+}} then_branch = @tf.IfRegion_then}
// CHECK: func @tf.IfRegion_else
// CHECK: func private @tf.IfRegion_else
// CHECK-NEXT: "tf.Neg"
// CHECK: func @tf.IfRegion_then
// CHECK: func private @tf.IfRegion_then
// CHECK-NEXT: "tf.Abs"
func @testNested(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
@ -169,10 +169,10 @@ func @testIf2Result(%arg0: tensor<i1>, %arg1: tensor<2xf32>) -> tensor<2xf32> {
// -----
// No inputs, some outputs for IfRegion
// CHECK: func @tf.IfRegion_else() -> tensor<2xf32>
// CHECK: func private @tf.IfRegion_else() -> tensor<2xf32>
// CHECK-NEXT: constant dense<1.000000e+00>
// CHECK-NEXT: "tf.Neg"
// CHECK: func @tf.IfRegion_then() -> tensor<2xf32>
// CHECK: func private @tf.IfRegion_then() -> tensor<2xf32>
// CHECK-NEXT: constant dense<0.000000e+00>
// CHECK-NEXT: "tf.Abs"
func @testSimple(%arg0: tensor<i1>) -> tensor<2xf32> {
@ -193,9 +193,9 @@ func @testSimple(%arg0: tensor<i1>) -> tensor<2xf32> {
// No outputs, some inputs for IfRegion
//
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>)
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<*xf32>)
// CHECK-NEXT: "tf.Neg"
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>)
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<*xf32>)
// CHECK-NEXT: "tf.Abs"
func @printer(tensor<*xf32>) -> ()
func @testNoOutputs(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> () {
@ -214,9 +214,9 @@ func @testNoOutputs(%arg0: tensor<i1>, %arg1: tensor<*xf32>) -> () {
// -----
// Check ToBool folding for IfRegion
// CHECK: func @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @tf.IfRegion_else(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Neg"
// CHECK: func @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK: func private @tf.IfRegion_then(%arg0: tensor<*xf32>) -> tensor<*xf32>
// CHECK-NEXT: "tf.Abs"
// CHECK-LABEL: @testToBoolFold
func @testToBoolFold(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32> {
@ -237,11 +237,11 @@ func @testToBoolFold(%arg0: tensor<i32>, %arg1: tensor<*xf32>) -> tensor<*xf32>
// -----
// Simple WhileRegion
// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"}
// CHECK: func private @tf.WhileRegion_body{{.+}}
// CHECK: "tf.Add"
// CHECK: constant dense<1>
// CHECK: "tf.Sub"
// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"}
// CHECK:func private @tf.WhileRegion_cond{{.+}}
// CHECK: constant dense<0>
// CHECK: "tf.NotEqual"
// CHECK-LABEL: testValidWhileRegion
@ -275,11 +275,11 @@ func @testValidWhileRegion(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor
// -----
// WhileRegion with type mismatch
// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"}
// CHECK: func private @tf.WhileRegion_body{{.+}}
// CHECK: "tf.Add"
// CHECK: constant dense<1>
// CHECK: "tf.Sub"
// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"}
// CHECK:func private @tf.WhileRegion_cond{{.+}}
// CHECK: constant dense<0>
// CHECK: "tf.NotEqual"
// CHECK-LABEL: testWhileRegionTypeMismatch
@ -309,11 +309,11 @@ func @testWhileRegionTypeMismatch(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) ->
// -----
// WhileRegion with constant sinking
// CHECK: func @tf.WhileRegion_body{{.+}}{sym_visibility = "private"}
// CHECK: func private @tf.WhileRegion_body{{.+}}
// CHECK: constant dense<1>
// CHECK: "tf.Add"
// CHECK: "tf.Sub"
// CHECK:func @tf.WhileRegion_cond{{.+}}{sym_visibility = "private"}
// CHECK:func private @tf.WhileRegion_cond{{.+}}
// CHECK: constant dense<0>
// CHECK: "tf.NotEqual"
// CHECK-LABEL: testWhileRegionConstantSink
@ -342,12 +342,12 @@ func @testWhileRegionConstantSink(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) ->
// -----
// WhileRegion with implicitly captured extern value in cond
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: "tf.Add"
// CHECK: constant dense<1>
// CHECK: "tf.Sub"
// CHECK: return %{{.+}}, %{{.+}}, %arg2 : tensor<*xf32>, tensor<i32>, tensor<i32>
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: "tf.NotEqual"(%arg1, %arg2)
// CHECK-LABEL: testWhileRegionExternInCond
func @testWhileRegionExternInCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
@ -376,12 +376,12 @@ func @testWhileRegionExternInCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %a
// -----
// WhileRegion with implicitly captured extern value in body
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: %0 = "tf.Add"(%arg0, %arg0)
// CHECK: %1 = "tf.Sub"(%arg1, %arg2)
// CHECK: return %0, %1, %arg2
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: constant dense<0>
// CHECK: "tf.NotEqual"
@ -412,9 +412,9 @@ func @testWhileRegionExternInBody(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %a
// -----
// WhileRegion with implicitly captured extern value in cond and body
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
// CHECK: return %{{.+}}, %{{.+}}, %arg2, %arg3
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>, %arg3: tensor<i32>)
// CHECK-LABEL: testWhileRegionExternInBodyAndCond
func @testWhileRegionExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
%cst = constant dense<4> : tensor<i32>
@ -443,9 +443,9 @@ func @testWhileRegionExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i
// -----
// WhileRegion with same value implicitly captured in cond and body
// CHECK: func @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: func private @tf.WhileRegion_body(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: return %{{.+}}, %{{.+}}, %arg2
// CHECK: func @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK: func private @tf.WhileRegion_cond(%arg0: tensor<*xf32>, %arg1: tensor<i32>, %arg2: tensor<i32>)
// CHECK-LABEL: testWhileRegionSameExternInBodyAndCond
func @testWhileRegionSameExternInBodyAndCond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>, %arg2 : tensor<i32>) -> tensor<*xf32> {
%cst = constant dense<4> : tensor<i32>
@ -559,9 +559,9 @@ func @testWhileRegionTrivialMultipleCasts(%arg0 : tensor<*xf32>, %arg1 : tensor<
// -----
// Almost trivially transformable with extern values
// CHECK: func @tf.WhileRegion_body
// CHECK: func private @tf.WhileRegion_body
// CHECK: call @while_body
// CHECK: @tf.WhileRegion_cond
// CHECK: func private @tf.WhileRegion_cond
// CHECK: call @while_cond
// CHECK-LABEL: testWhileRegionExtern
func @while_cond(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tensor<i1>
@ -589,9 +589,9 @@ func @testWhileRegionExtern(%arg0 : tensor<*xf32>, %arg1 : tensor<i32>) -> tenso
// -----
// Almost trivially transformable, mismatching block arguments
// CHECK: func @tf.WhileRegion_body
// CHECK: func private @tf.WhileRegion_body
// CHECK: call @while_body
// CHECK: @tf.WhileRegion_cond
// CHECK: func private @tf.WhileRegion_cond
// CHECK: call @while_cond
// CHECK-LABEL: testWhileRegionBlockArgMismatch
func @while_cond(%arg0 : tensor<i32>, %arg1 : tensor<*xf32>) -> tensor<i1>

View File

@ -17,8 +17,8 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 12 : i32, p
return %1 : tensor<f32>
}
// CHECK-NOT: func @callee
func @callee(%arg0: tensor<!tf.resource>) -> tensor<*xf32> attributes {sym_visibility = "private", tf.signature.is_stateful} {
// CHECK-NOT: func private @callee
func private @callee(%arg0: tensor<!tf.resource>) -> tensor<*xf32> attributes {tf.signature.is_stateful} {
%0 = "tf.ReadVariableOp"(%arg0) {device = ""} : (tensor<!tf.resource>) -> tensor<*xf32>
return %0 : tensor<*xf32>
}

View File

@ -644,7 +644,7 @@ func @callee(%arg0: tensor<f32>, %arg1: tensor<*x!tf.resource<tensor<f32>>>, %ar
%2 = "tf.AddV2"(%1, %arg2) : (tensor<f32>, tensor<f32>) -> tensor<f32>
return %2 : tensor<f32>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK: func private @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD0:.*]] = "tf.AddV2"(%[[A1]], %[[A0]])
// CHECK-NEXT: %[[ADD1:.*]] = "tf.AddV2"(%[[ADD0]], %[[A2]])
// CHECK-NEXT: return %[[ADD1]]
@ -691,7 +691,7 @@ func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>, %arg1: tensor<*x!tf.res
"tf.AssignVariableOp"(%arg0, %1) {dtype = i32} : (tensor<*x!tf.resource<tensor<f32>>>, tensor<f32>) -> ()
return %arg0 : tensor<*x!tf.resource<tensor<f32>>>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK: func private @callee_resource_lifted(%[[A0:.*]]: tensor<f32>, %[[A1:.*]]: tensor<f32>, %[[A2:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: %[[ADD:.*]] = "tf.AddV2"(%[[A1]], %[[A2]])
// CHECK-NEXT: return %[[ADD]]
@ -743,7 +743,7 @@ func @callee(%arg0: tensor<*x!tf.resource<tensor<f32>>>) -> tensor<f32> {
return %1 : tensor<f32>
}
// CHECK: func @callee_resource_lifted(%[[A0:.*]]: tensor<f32>) -> tensor<f32>
// CHECK: func private @callee_resource_lifted(%[[A0:.*]]: tensor<f32>) -> tensor<f32>
// CHECK-NEXT: return %[[A0]]
// -----
@ -1249,3 +1249,40 @@ func @callee(%arg0: !tf_res) -> tensor<i1> {
// CHECK-NEXT: return [[TRUE]] :
return %0 : tensor<i1>
}
// -----
// Tests passthrough tf.Cast ops are removed.
!tf_res = type tensor<*x!tf.resource<tensor<f32>>>
// CHECK-LABEL: func @tpu_computation
func @tpu_computation(%arg0: !tf_res) {
"tf_device.cluster"() ( {
%0 = "tf.While"(%arg0) {body = @while_body, cond = @while_cond, is_stateless = false} : (!tf_res) -> !tf_res
%1 = "tf.WhileRegion"(%arg0) ( {
^cond(%carg0: !tf_res):
%2 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
"tf.Yield"(%2) : (tensor<i1>) -> ()
}, {
^body(%barg0: !tf_res):
// CHECK-NOT: tf.Cast
%2 = "tf.Cast"(%barg0) : (!tf_res) -> !tf_res
"tf.Yield"(%2) : (!tf_res) -> ()
}) {is_stateless = false} : (!tf_res) -> !tf_res
tf_device.return
}) {} : () -> ()
return
}
func @while_cond(%arg0: !tf_res) -> tensor<i1> {
%0 = "tf.Const"() {value = dense<true> : tensor<i1>} : () -> tensor<i1>
return %0 : tensor<i1>
}
// CHECK-LABEL: func @while_body
func @while_body(%arg0: !tf_res) -> !tf_res {
// CHECK-NOT: tf.Cast
%0 = "tf.Cast"(%arg0) : (!tf_res) -> !tf_res
return %0 : !tf_res
}

View File

@ -439,16 +439,16 @@ module attributes {tf.versions = {bad_consumers = [], min_consumer = 0 : i32, pr
return %arg0 : tensor<2xi32>
}
// Test not updating call site if a std.call is used.
// Test iteratively updating call site if a std.call is used.
// CHECK-LABEL: func @call_partitioned_call2(
// CHECK-SAME: -> tensor<*xi32>
// CHECK-SAME: -> tensor<1xi32>
func @call_partitioned_call2() -> tensor<*xi32> {
// CHECK: () -> tensor<*xi32>
// CHECK: () -> tensor<1xi32>
%0 = call @partitioned_called_func2() : () -> tensor<*xi32>
return %0 : tensor<*xi32>
}
// CHECK-LABEL: func @partitioned_called_func2(
// CHECK-SAME: -> tensor<*xi32>
// CHECK-SAME: -> tensor<1xi32>
func @partitioned_called_func2() -> (tensor<*xi32>) {
%0 = "tf.Const"() {value = dense<-1> : tensor<1xi32>} : () -> tensor<1xi32>
%1 = tensor_cast %0 : tensor<1xi32> to tensor<*xi32>

View File

@ -287,14 +287,14 @@ func @main(%arg0: tensor<i1>) -> () {
}
// CHECK: func @callee(%[[AARG0:.*]]: tensor<!tf.resource>, %[[AARG1:.*]]: tensor<i1>) -> tensor<!tf.resource>
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> attributes {sym_visibility = "public"} {
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> {
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
// CHECK: tf.StackPushV2"
%push = "tf.StackPushV2"(%arg0, %elem) {swap_memory = false} : (tensor<!tf.resource>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @callee_stack_decomposed(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
// CHECK: func private @callee_stack_decomposed(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
// CHECK-NOT: "tf.StackPushV2"
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
// CHECK: "tf.AssignVariableOp"(%[[TARG0:.*]], %[[UPDATE]])
@ -326,8 +326,8 @@ func @main(%arg0: tensor<i1>) -> () {
return
}
// CHECK: func @callee(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
func @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> attributes {sym_visibility = "private"} {
// CHECK: func private @callee(%[[ARG0:.*]]: tensor<!tf.resource<tensor<10xf32>>>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<!tf.resource<tensor<1xi32>>>)
func private @callee(%arg0: tensor<!tf.resource>, %arg1: tensor<i1>) -> tensor<!tf.resource> {
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
// CHECK-NOT: "tf.StackPushV2"
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
@ -348,7 +348,7 @@ func @main() -> () {
return
}
// CHECK: func @callee()
func @callee() -> () attributes {sym_visibility = "public"} {
func @callee() -> () {
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NOT: tf.Stack
%stack = "tf.StackV2"(%max_size) {elem_type = f32, stack_name = "s"} : (tensor<i32>) -> tensor<!tf.resource>

View File

@ -432,7 +432,7 @@ func @main() -> () {
}
// CHECK-LABEL: func @callee
// CHECK-SAME: (%[[OCARG0:.*]]: tensor<!tf.resource>) -> tensor<!tf.resource>
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sym_visibility = "public"} {
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
%const1 = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>
%elem = "tf._SomeOp"() : () -> tensor<3xf32>
%flow = "tf.Const"() {value = dense<1.0> : tensor<f32>} : () -> tensor<f32>
@ -442,7 +442,7 @@ func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sy
%gwrite2 = "tf.TensorArrayWriteV3"(%grad2#0, %const1, %elem, %grad2#1) : (tensor<!tf.resource>, tensor<i32>, tensor<3xf32>, tensor<f32>) -> tensor<f32>
return %arg0 : tensor<!tf.resource>
}
// CHECK: func @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
// CHECK: func private @callee_tensorarray_decomposed(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
@ -480,8 +480,8 @@ func @main() -> () {
%read = "tf.TensorArrayReadV3"(%call2, %index, %ta#1) : (tensor<!tf.resource>, tensor<i32>, tensor<f32>) -> tensor<3xf32>
return
}
// CHECK: func @callee(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sym_visibility = "private"} {
// CHECK: func private @callee(%[[CARG0:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG1:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[CARG2:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>)
func private @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> {
// CHECK: %[[READ1:.*]] = "tf.ReadVariableOp"(%[[CARG1]]) : (tensor<!tf.resource<tensor<5x3xf32>>>) -> tensor<5x3xf32>
// CHECK: %[[UPDATE1:.*]] = "tf.XlaDynamicUpdateSlice"(%[[READ1]],
// CHECK: "tf.AssignVariableOp"(%[[CARG1]], %[[UPDATE1]])
@ -508,8 +508,8 @@ func @main() -> () {
%call = "tf.PartitionedCall"() {f = @callee, config = "", config_proto = "", executor_type = ""} : () -> tensor<i32>
return
}
// CHECK: func @callee() -> tensor<i32>
func @callee() -> tensor<i32> attributes {sym_visibility = "public"} {
// CHECK: func private @callee() -> tensor<i32>
func @callee() -> tensor<i32> {
%size = "tf.Const"() {value = dense<5> : tensor<i32>} : () -> tensor<i32>
// CHECK: "tf.MlirLocalVarOp"() : () -> tensor<!tf.resource<tensor<5xf32>>>
// CHECK: "tf.AssignVariableOp"
@ -567,7 +567,7 @@ func @main() -> () {
return
}
// CHECK-LABEL: func @callee
// CHECK-LABEL: func private @callee
// CHECK-SAME: %[[VAR:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>, %[[GVAR:.*]]: tensor<!tf.resource<tensor<5x3xf32>>>
func @callee(%arg0: tensor<!tf.resource>) -> tensor<!tf.resource> attributes {sym_visibility = "private"} {
%index = "tf.Const"() {value = dense<1> : tensor<i32>} : () -> tensor<i32>

View File

@ -472,14 +472,14 @@ func @main(%arg0: tensor<i1>) -> () {
}
// CHECK: func @callee(%[[AARG0:.*]]: tensor<!tf.variant<tensor<f32>>>, %[[AARG1:.*]]: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>>
func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>> attributes {sym_visibility = "public"} {
func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>> {
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
// CHECK: "tf.TensorListPushBack"
%push = "tf.TensorListPushBack"(%arg0, %elem) : (tensor<!tf.variant<tensor<f32>>>, tensor<f32>) -> tensor<!tf.variant<tensor<f32>>>
return %push : tensor<!tf.variant<tensor<f32>>>
}
// CHECK: func @callee_tensorlist_decomposed(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
// CHECK: func private @callee_tensorlist_decomposed(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
// CHECK-NOT: "tf.TensorListPushBack"
// CHECK: %[[UPDATE:.*]] = "tf.XlaDynamicUpdateSlice"
// CHECK: %[[CONST1:.*]] = "tf.Const"() {value = dense<1> : tensor<1xi32>} : () -> tensor<1xi32>
@ -514,7 +514,7 @@ func @main(%arg0: tensor<i1>) -> () {
return
}
// CHECK: func @callee(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
// CHECK: func private @callee(%[[ARG0:.*]]: tensor<10xf32>, %[[ARG1:.*]]: tensor<i1>, %[[ARG2:.*]]: tensor<1xi32>) -> (tensor<10xf32>, tensor<1xi32>)
func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tensor<!tf.variant<tensor<f32>>> attributes {sym_visibility = "private"} {
%elem = "tf._SomeOp"(%arg1) : (tensor<i1>) -> tensor<f32>
@ -533,12 +533,12 @@ func @callee(%arg0: tensor<!tf.variant<tensor<f32>>>, %arg1: tensor<i1>) -> tens
// Tests PartitionedCall op with no signature change on callee.
// CHECK-LABEL: func @main
func @main() -> () {
func @main() {
"tf.PartitionedCall"() {f = @callee, config = "", config_proto = "", executor_type = ""} : () -> ()
return
}
// CHECK: func @callee()
func @callee() -> () attributes {sym_visibility = "public"} {
// CHECK: func private @callee()
func @callee() {
%elem_shape = "tf.Const"() {value = dense<> : tensor<0xi32>} : () -> tensor<0xi32>
%max_size = "tf.Const"() {value = dense<10> : tensor<i32>} : () -> tensor<i32>
// CHECK-NOT: tf.EmptyTensorList

View File

@ -62,7 +62,7 @@ class TestModule(tf.Module):
# CHECK-SAME: attributes{{.*}}tf_saved_model.exported_names = ["caller"]
# CHECK: "tf.StatefulPartitionedCall"{{.*}}f = @[[CALLEE_INTERNAL]]
#
# CHECK: func @[[CALLEE_INTERNAL]]
# CHECK: func private @[[CALLEE_INTERNAL]]
# CHECK-NOT: tf_saved_model.exported_names
@tf.function(input_signature=[tf.TensorSpec([], tf.float32)])

View File

@ -35,8 +35,8 @@ from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# CHECK-SAME: else_branch = @[[else]]
# CHECK-SAME: then_branch = @[[then]]
# CHECK: func @[[else]](
# CHECK: func @[[then]](
# CHECK: func private @[[else]](
# CHECK: func private @[[then]](
def Test():

View File

@ -26,11 +26,11 @@ import tempfile
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[init:.*]]]} : () -> ()
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset1:__tf_saved_model_asset1_.*]]"}
# CHECK: "tf_saved_model.asset"() {filename = {{.*}}, sym_name = "[[asset0:__tf_saved_model_asset0_.*]]"}
# CHECK: func [[init]]
# CHECK: func @[[init]]
# CHECK-SAME: [[ARG0:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset0]]}
# CHECK-SAME: [[ARG1:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[asset1]]}
# CHECK-NEXT: [[R0:%.*]] = "tf.HashTableV2"()

View File

@ -34,9 +34,9 @@ from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# CHECK-SAME: producer
# CHECK: "tf_saved_model.global_tensor"()
# CHECK: "tf_saved_model.session_initializer"() {initializer = [[init:@.*]]} : () -> ()
# CHECK: "tf_saved_model.session_initializer"() {initializers = [@[[init:.*]]]} : () -> ()
# CHECK: func [[init]]
# CHECK: func @[[init]]
# CHECK-NEXT: [[R5:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R6:%.*]] = "tf.Const"()
# CHECK-NEXT: [[R7:%.*]] = "tf.HashTableV2"()
@ -89,4 +89,4 @@ def Test():
if __name__ == '__main__':
common_v1.set_tf_options()
common_v1.do_test(Test)
common_v1.do_test(Test, canonicalize=True)

View File

@ -0,0 +1,80 @@
# Copyright 2019 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# RUN: %p/import_restore_v1 | FileCheck %s
# pylint: disable=missing-docstring,line-too-long
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import tensorflow.compat.v1 as tf
from tensorflow.compiler.mlir.tensorflow.tests.tf_saved_model import common_v1
# Verify that the tf.versions attribute exists. It is difficult to enforce
# contents, since the version numbers change over time. The conversion logic
# itself is verified in the common graphdef converter, so here just assert
# it is being invoked.
# CHECK: module
# CHECK-SAME: tf.versions
# CHECK-SAME: bad_consumers
# CHECK-SAME: min_consumer
# CHECK-SAME: producer
# CHECK: tf_saved_model.session_initializer
# CHECK-SAME: initializers = [@[[restore:.*]]]
# CHECK: "tf_saved_model.asset"()
# CHECK-SAME: {filename = [[filename:.*]], sym_name = "[[sym_name:.*]]"} : () -> ()
# CHECK: func @[[restore]](
# CHECK-SAME: [[variable_path:%.*]]: tensor<!tf.string> {tf_saved_model.bound_input = @[[sym_name]]}
# CHECK-SAME: tf_saved_model.exported_names = ["{{__tf_saved_model_session_initializer.*}}"]
# CHECK: [[v0:%.*]] = "tf.RestoreV2"([[variable_path]]
# CHECK: [[v1:%.*]] = "tf.Identity"([[v0]])
# CHECK: [[handle:%.*]] = "tf.VarHandleOp"
# CHECK-SAME: shared_name = [[shared_name:".*"]]
# CHECK: "tf.AssignVariableOp"([[handle]], [[v1]])
# CHECK: func {{@[a-zA-Z_0-9]+}}(
# CHECK-SAME: tf_saved_model.exported_names = ["key"]
# CHECK: tf.VarHandleOp
# CHECK-SAME: shared_name = [[shared_name]]
def Test():
x = tf.constant([[1.0], [1.0], [1.0]])
y = tf.compat.v1.get_variable(
name='y',
shape=(1, 3),
initializer=tf.random_normal_initializer(),
trainable=True)
r = tf.matmul(x, y)
tensor_info_x = tf.compat.v1.saved_model.utils.build_tensor_info(x)
tensor_info_r = tf.compat.v1.saved_model.utils.build_tensor_info(r)
return {
'key': (tf.compat.v1.saved_model.signature_def_utils.build_signature_def(
inputs={'x': tensor_info_x},
outputs={'r': tensor_info_r},
method_name='some_function'))
}, None, None
if __name__ == '__main__':
common_v1.set_tf_options()
common_v1.do_test(Test, use_lite=True)

View File

@ -4,7 +4,7 @@ module attributes {tf_saved_model.semantics} {
// CHECK: tf_saved_model.session_initializer
"tf_saved_model.session_initializer"() {
initializer = @init
initializers = [@init]
} : () -> ()
// CHECK: tf_saved_model.asset

View File

@ -277,7 +277,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function does not exist}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
}
// -----
@ -285,7 +285,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have no output}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
@ -298,7 +298,7 @@ module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
// expected-error@+1 {{there must be no more than one session_initializer op}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() -> tensor<1xf32> attributes {sym_visibility = "private"} {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return %0 : tensor<1xf32>
@ -336,7 +336,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function does not exist}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
}
// -----
@ -344,7 +344,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have no output}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
@ -356,9 +356,9 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
// expected-error@+1 {{there must be no more than one session_initializer op}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() -> (tensor<1xf32> {tf_saved_model.index_path = ["output"]})
attributes { tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"] } {
%0 = "tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
@ -371,7 +371,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should be exported}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() attributes {sym_visibility = "private"} {
return
}
@ -382,7 +382,7 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// expected-error@+1 {{the initializer function should have only one exported name}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() attributes { tf_saved_model.exported_names = ["a", "b"] } {
return
}

View File

@ -111,14 +111,14 @@ module attributes {tf_saved_model.semantics} {
return %val : tensor<f32>
}
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
// CHECK: func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
// CHECK: func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32>
@ -145,14 +145,14 @@ module attributes {tf_saved_model.semantics} {
return %val : tensor<f32>
}
// CHECK: func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
// CHECK: func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func private @f_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f_callee_callee} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
// CHECK: func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func private @f_callee_callee(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32>
@ -178,14 +178,14 @@ module attributes {tf_saved_model.semantics} {
}
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
// CHECK: func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @g} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
// CHECK: func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
// CHECK: func private @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func private @g(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%val = "tf.PartitionedCall"(%arg0) {config = "", config_proto = "", executor_type = "", f = @f} : (tensor<*x!tf.resource>) -> (tensor<f32>)
return %val : tensor<f32>
}
@ -211,8 +211,8 @@ module attributes {tf_saved_model.semantics} {
}
// CHECK: func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> attributes {sym_visibility = "private"} {
// CHECK: func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32>
func private @f(%arg0: tensor<*x!tf.resource>) -> tensor<f32> {
%c0 = "tf.Const"() { value = dense<1.0> : tensor<f32> } : () -> tensor<f32>
"tf.AssignAddVariableOp"(%arg0, %c0) : (tensor<*x!tf.resource>, tensor<f32>) -> ()
return %c0 : tensor<f32>

View File

@ -9,14 +9,14 @@ module attributes {tf_saved_model.semantics} {
module attributes {tf_saved_model.semantics} {
// Test case: No matching function for the given session initializer.
// expected-error@+1 {{'tf_saved_model.session_initializer' op the initializer function does not exist}}
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
}
// -----
module attributes {tf_saved_model.semantics} {
// Test case: Invalid multiple blocks in the initializer funcion.
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
// expected-error@+1 {{expects exactly one block in the MLIR function}}
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
br ^bb1
@ -32,7 +32,7 @@ module attributes {tf_saved_model.semantics} {
// CHECK: func @init()
// CHECK: tf.Const
// CHECK: return
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
"tf.Const"() {value = dense<[1.0]> : tensor<1xf32> } : () -> tensor<1xf32>
return
@ -48,7 +48,7 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
// CHECK-NOT: tf.Const
// CHECK-NOT: tf.AssignAddVariableOp
// CHECK: return
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<2x8xi32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "w"} : () -> tensor<*x!tf.resource<tensor<2xi32>>>
@ -69,7 +69,7 @@ module attributes {tf_saved_model.semantics, tf_saved_model.under_construction}
// CHECK-NOT: tf.Const
// CHECK-NOT: tf.AssignAddVariableOp
// CHECK: return
"tf_saved_model.session_initializer"() { initializer = @init } : () -> ()
"tf_saved_model.session_initializer"() { initializers = [@init] } : () -> ()
func @init() attributes {tf_saved_model.exported_names = ["__tf_saved_model_session_initializer"]} {
%0 = "tf.VarHandleOp"() {container = "c", shared_name = "v"} : () -> tensor<*x!tf.resource<tensor<2x8xi32>>>
%1 = "tf.VarHandleOp"() {container = "c", shared_name = "w"} : () -> tensor<*x!tf.resource<tensor<2xi32>>>

View File

@ -6,12 +6,10 @@
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*x!tf.resource<tensor<64xf32>>>
// CHECK-SAME: %[[ARG_2:.*]]: tensor<*x!tf.resource<tensor<16xf32>>>
// CHECK-SAME: %[[ARG_3:.*]]: tensor<!tf.string>
func @merge_same_device_variables(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
%arg2: tensor<*x!tf.resource<tensor<16xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"},
%arg3: tensor<!tf.string>) {
%arg2: tensor<*x!tf.resource<tensor<16xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:CPU:0"}) {
// CHECK-NEXT: %[[ID_0:.*]] = "tf.IdentityN"(%[[ARG_0]])
%id0 = "tf.IdentityN"(%arg0) {device = "/job:localhost/replica:0/task:0/device:TPU:0"}
: (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<*x!tf.resource<tensor<32xf32>>>
@ -19,15 +17,27 @@ func @merge_same_device_variables(
%read0 = "tf.ReadVariableOp"(%id0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
%read2 = "tf.ReadVariableOp"(%arg2) : (tensor<*x!tf.resource<tensor<16xf32>>>) -> tensor<16xf32>
// CHECK-NEXT: %[[EXE:.*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[ARG_3]])
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
%compile:2 = "tf_device.launch"() ( {
// CHECK: tf._TPUCompileMlir
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main(%arg0: tensor<32xf32> {tf.aliasing_output = 0 : i64},
// CHECK-SAME: %arg1: tensor<64xf32>, %arg2: tensor<16xf32>)
%0:2 = "tf._TPUCompileMlir"() {
metadata = "",
mlir_module = "module attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor<32xf32>, %arg1: tensor<64xf32>, %arg2: tensor<16xf32>) -> (tensor<32xf32>, tensor<16xf32>) {\0A %0:2 = \22tf.A\22(%arg0, %arg1, %arg2) : (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>) -> (tensor<32xf32>, tensor<16xf32>)\0A return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32>\0A }\0A}"
} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %0#0, %0#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
// CHECK: %[[EXE:.*]] = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[ID_0]], %[[ARG_1]], %[[READ_2]], %[[COMPILE]]#1)
// CHECK-SAME: device_var_reads_indices = [0, 1],
// CHECK-SAME: device_var_updates_indices = [0, -1]
%execute:2 = "tf_device.launch"() ( {
%0:2 = "tf.TPUExecute"(%read0, %read1, %read2, %arg3) {
%0:2 = "tf.TPUExecute"(%read0, %read1, %read2, %compile#1) {
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<16xf32>],
Tresults = [tensor<32xf32>, tensor<16xf32>]}
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
: (tensor<32xf32>, tensor<64xf32>, tensor<16xf32>, tensor<2x!tf.string>) -> (tensor<32xf32>, tensor<16xf32>)
tf_device.return %0#0, %0#1 : tensor<32xf32>, tensor<16xf32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<16xf32>)
// CHECK-NEXT: tf_device.return
@ -44,26 +54,35 @@ func @merge_same_device_variables(
// Tests that the pass do not check devices for replicated region.
// CHECK-LABEL: func @merge_replicated_variables
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<!tf.string>,
// CHECK-SAME: %[[ARG_2:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>,
// CHECK-SAME: %[[ARG_3:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>, %[[ARG_1:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>,
// CHECK-SAME: %[[ARG_2:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
func @merge_replicated_variables(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg1: tensor<!tf.string>,
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg3: tensor<*x!tf.resource<tensor<32xf32>>>) {
%arg1: tensor<*x!tf.resource<tensor<32xf32>>>,
%arg2: tensor<*x!tf.resource<tensor<32xf32>>>) {
// CHECK-NEXT: %[[READ_0:.*]] = "tf.ReadVariableOp"(%[[ARG_0]])
%read0 = "tf.ReadVariableOp"(%arg0) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
// CHECK-NEXT: tf_device.replicate([%[[ARG_2]], %[[ARG_3]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>)
tf_device.replicate([%arg2, %arg3] as %r: tensor<*x!tf.resource<tensor<32xf32>>>) {n = 2 : i32} {
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
%compile:2 = "tf_device.launch"() ( {
// CHECK: tf._TPUCompileMlir
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32> {tf.aliasing_output = 0 : i64})
%0:2 = "tf._TPUCompileMlir"() {
metadata = "",
mlir_module = "module attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) -> (tensor<32xf32>) {\0A %0 = \22tf.A\22(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<32xf32>)\0A return %0 : tensor<32xf32>\0A }\0A}"
} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %0#0, %0#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
// CHECK: tf_device.replicate([%[[ARG_1]], %[[ARG_2]]] as %[[R_ARG:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>)
tf_device.replicate([%arg1, %arg2] as %r: tensor<*x!tf.resource<tensor<32xf32>>>) {n = 2 : i32} {
// CHECK-NEXT: "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[ARG_1]])
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[R_ARG]], %[[COMPILE]]#1)
// CHECK-SAME: device_var_reads_indices = [1],
// CHECK-SAME: device_var_updates_indices = [0]
%read1 = "tf.ReadVariableOp"(%r) : (tensor<*x!tf.resource<tensor<32xf32>>>) -> tensor<32xf32>
%execute = "tf_device.launch"() ( {
%0 = "tf.TPUExecute"(%read0, %read1, %arg1)
: (tensor<32xf32>, tensor<32xf32>, tensor<!tf.string>) -> tensor<32xf32>
%0 = "tf.TPUExecute"(%read0, %read1, %compile#1)
: (tensor<32xf32>, tensor<32xf32>, tensor<2x!tf.string>) -> tensor<32xf32>
tf_device.return %0 : tensor<32xf32>
}) {device = ""} : () -> tensor<32xf32>
// CHECK-NEXT: tf_device.return
@ -86,7 +105,6 @@ func @merge_replicated_variables(
// CHECK-SAME: %[[ARG_0:.*]]: tensor<*x!tf.resource<tensor<32xf32>>>
// CHECK-SAME: %[[ARG_1:.*]]: tensor<*x!tf.resource<tensor<64xf32>>>
// CHECK-SAME: %[[ARG_2:.*]]: tensor<32xf32>
// CHECK-SAME: %[[ARG_3:.*]]: tensor<!tf.string>
// CHECK-SAME: %[[ARG_4:.*]]: tensor<*x!tf.resource<tensor<8xf32>>>
// CHECK-SAME: %[[ARG_5:.*]]: tensor<*x!tf.resource<tensor<2xf32>>>
// CHECK-SAME: %[[ARG_6:.*]]: tensor<2xf32>
@ -94,7 +112,6 @@ func @interferencing_accesses(
%arg0: tensor<*x!tf.resource<tensor<32xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
%arg1: tensor<*x!tf.resource<tensor<64xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
%arg2: tensor<32xf32>,
%arg3: tensor<!tf.string>,
%arg4: tensor<*x!tf.resource<tensor<8xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
%arg5: tensor<*x!tf.resource<tensor<2xf32>>> {tf.device = "/job:localhost/replica:0/task:0/device:TPU:0"},
%arg6: tensor<2xf32>) -> (tensor<8xf32>) {
@ -108,15 +125,26 @@ func @interferencing_accesses(
"tf.AssignVariableOp"(%arg5, %arg6) : (tensor<*x!tf.resource<tensor<2xf32>>>, tensor<2xf32>) -> ()
%read1 = "tf.ReadVariableOp"(%arg1) : (tensor<*x!tf.resource<tensor<64xf32>>>) -> tensor<64xf32>
%read2 = "tf.ReadVariableOp"(%arg4) : (tensor<*x!tf.resource<tensor<8xf32>>>) -> tensor<8xf32>
// CHECK-NEXT: %[[EXE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[ARG_3]])
// CHECK: %[[COMPILE:.*]]:2 = "tf_device.launch"
%compile:2 = "tf_device.launch"() ( {
// CHECK: tf._TPUCompileMlir
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32> {tf.aliasing_output = 1 : i64})
%0:2 = "tf._TPUCompileMlir"() {
metadata = "",
mlir_module = "module attributes {tf.versions = {producer = 888 : i32}} {\0A func @main(%arg0: tensor<32xf32>, %arg1: tensor<32xf32>) -> (tensor<32xf32>) {\0A %0 = \22tf.A\22(%arg0, %arg1) : (tensor<32xf32>, tensor<32xf32>) -> (tensor<32xf32>)\0A return %0 : tensor<32xf32>\0A }\0A}"
} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
tf_device.return %0#0, %0#1 : tensor<!tf.string>, tensor<2x!tf.string>
}) {device = "/job:worker/replica:0/task:0/device:CPU:0"} : () -> (tensor<!tf.string>, tensor<2x!tf.string>)
// CHECK: %[[EXE:.*]]:2 = "tf_device.launch"
// CHECK-NEXT: "tf.TPUExecuteAndUpdateVariables"(%[[READ_0]], %[[ARG_1]], %[[ARG_4]], %[[READ_5]], %[[COMPILE]]#1)
// CHECK-SAME: device_var_reads_indices = [1, 2],
// CHECK-SAME: device_var_updates_indices = [1, -1]
%execute:3 = "tf_device.launch"() ( {
%0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %arg3) {
%0:3 = "tf.TPUExecute"(%read0, %read1, %read2, %read5, %compile#1) {
Targs = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>],
Tresults = [tensor<32xf32>, tensor<64xf32>, tensor<8xf32>]}
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<!tf.string>)
: (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>, tensor<2xf32>, tensor<2x!tf.string>)
-> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)
tf_device.return %0#0, %0#1, %0#2 : tensor<32xf32>, tensor<64xf32>, tensor<8xf32>
}) {device = "/job:localhost/replica:0/task:0/device:TPU:0"} : () -> (tensor<32xf32>, tensor<64xf32>, tensor<8xf32>)

View File

@ -859,7 +859,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @nested_func
// CHECK-SAME: func private @nested_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: device = "/job:worker/replica:0/task:0/device:CPU:0"
@ -908,7 +908,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @referenced_func
// CHECK-SAME: func private @referenced_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: "tf_device.launch"
@ -1007,7 +1007,7 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-COUNT-2: call @referenced_func
// CHECK-COUNT-1: func @referenced_func
// CHECK-COUNT-1: func private @referenced_func
// CHECK-SAME: tf.D
// CHECK-NOT: func = @tpu0_func
// CHECK: "tf_device.launch"
@ -1161,13 +1161,13 @@ module attributes {tf.versions = {producer = 888 : i32}, tf.devices = ["/job:wor
// CHECK-SAME: mlir_module
// CHECK-SAME: func @main
// CHECK-SAME: tf.B
// CHECK-SAME: func @referenced_func3
// CHECK-SAME: func private @referenced_func3
// CHECK-SAME: tf.I
// CHECK-SAME: func @referenced_func2
// CHECK-SAME: func private @referenced_func2
// CHECK-SAME: tf.H
// CHECK-SAME: func @referenced_func1
// CHECK-SAME: func private @referenced_func1
// CHECK-SAME: tf.G
// CHECK-SAME: func @referenced_func0
// CHECK-SAME: func private @referenced_func0
// CHECK-SAME: tf.F
// CHECK: "tf_device.launch"
// CHECK-NEXT: "tf.TPUCompileSucceededAssert"(%[[COMPILE_OUTPUT]]#0)

View File

@ -44,9 +44,9 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:CPU:0"
%10 = "tf.Identity"(%9) {device = ""} : (tensor<i1>) -> tensor<i1>
return %10 : tensor<i1>
}
// CHECK-LABEL: func @_func
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
// CHECK-LABEL: func private @_func
// CHECK-SAME: [[FUNCINPUT0:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT1:%.*]]: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[FUNCINPUT2:%.*]]: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, [[VAL_59:%.*]]: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<7x7x3x64xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<i64> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
%0 = "tf.Const"() {value = dense<1> : tensor<i64>} : () -> tensor<i64>
%1 = "tf.Const"() {value = dense<0> : tensor<1x1xi32>} : () -> tensor<1x1xi32>
%2 = "tf.Const"() {value = dense<[7, 7, 3, 64]> : tensor<4xi32>} : () -> tensor<4xi32>
@ -112,9 +112,9 @@ module attributes {tf.devices = {"/job:localhost/replica:0/task:0/device:COMPOSI
}
return
}
// CHECK-LABEL: func @_func
// CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
func @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) attributes {sym_visibility = "private"} {
// CHECK-LABEL: func private @_func
// CHECK-SAME: [[FUNCINPUT00:.*]]: tensor<2x112x112x12xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
func private @_func(%arg0: tensor<2x224x224x3xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg1: tensor<2x1xf32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg2: tensor<7x7x3x64xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg3: tensor<64x1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg4: tensor<1001xf32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg5: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg6: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg7: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, %arg8: tensor<f32> {mhlo.is_same_data_across_replicas, mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) -> (tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}, tensor<f32> {mhlo.sharding = "\08\01\1A\01\01\22\01\00"}) {
%0 = "tf.Const"() {value = dense<2.000000e+00> : tensor<f32>} : () -> tensor<f32>
%1 = "tf.Const"() {value = dense<1.000000e+00> : tensor<f32>} : () -> tensor<f32>
%2 = "tf.Const"() {value = dense<-1> : tensor<i32>} : () -> tensor<i32>

View File

@ -112,7 +112,8 @@ LogicalResult ConvertIfOp(IfOp if_op) {
LogicalResult ConvertWhileOp(WhileOp while_op) {
auto while_region = OpBuilder(while_op).create<TF::WhileRegionOp>(
while_op.getLoc(), while_op.getResultTypes(), while_op.input(),
while_op.is_stateless(), while_op.parallel_iterations());
while_op.output_shapes(), while_op.parallel_iterations(),
while_op.is_stateless());
CopyDeviceAndUnderscoredAttributes(while_op, while_region);
YieldOp cond_yield =

Some files were not shown because too many files have changed in this diff Show More