Merged commit includes the following changes:

277453541  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Fix import path.

--
277445856  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Separate the `update_bazel_(linux|macos)` call from the `kokoro_init_(linux|macos)` function. Modify all build scripts to run `update_bazel_*` right after the `kokoro_init_*` call.

--
277440435  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Automated rollback of changelist 277217247.

--
277434196  by A. Unique TensorFlower<gardener@tensorflow.org>:

    TFLM: Move runtime tensor initialization upfront and allow tensor_info to be freed.

    This CL will save `sizeof(TensorInfo) * tensors_size` bytes in tensor_arena.

    The cascaded design of SimpleMemoryAllocator provides a safe and intuitive abstraction. The child allocator serves as a temporary allocator: whatever is allocated in the child is freed once the child goes out of scope, and it does not affect allocations in the parent - the parent allocator is locked while the child allocator is in use.

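    A minimal usage sketch of the child-allocator API added in this change (the function name, buffer size, and alignment values are illustrative, not taken from this CL):

      #include <cstdint>

      #include "tensorflow/lite/experimental/micro/simple_memory_allocator.h"

      // Hypothetical sketch (not part of this CL) showing the intended pattern.
      void ChildAllocatorSketch() {
        uint8_t arena[1024];
        tflite::SimpleMemoryAllocator allocator(arena, sizeof(arena));
        // Long-lived data is allocated from the parent allocator.
        uint8_t* persistent = allocator.AllocateFromTail(16, 4);
        {
          // Temporary data goes through a child allocator; while it is alive
          // the parent is locked and its AllocateFromTail() returns nullptr.
          tflite::SimpleMemoryAllocator tmp = allocator.CreateChildAllocator();
          uint8_t* scratch = tmp.AllocateFromTail(32, 4);
        }  // Child destroyed: its memory is reclaimed and the parent unlocks.
        uint8_t* more = allocator.AllocateFromTail(16, 4);
      }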
--
277428110  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Automated rollback of changelist 276610559.

--
277426570  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Exit gracefully when outside compilation is invoked outside TPUReplicateContext
    scope.

--
277423900  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Tensor tracer: adding ShapeN to the list of ops that should not be traced.

--
277422316  by A. Unique TensorFlower<gardener@tensorflow.org>:

    [XLA] Don't try to inline dead instructions.

    Inlining one call instruction can make other call instructions dead.
    Don't inline those dead instructions.

--
277414458  by A. Unique TensorFlower<gardener@tensorflow.org>:

    Op documentation update.
    	update of g3doc/tfl_ops.md
    	update of g3doc/tf_ops.md

--

PiperOrigin-RevId: 277453541
A. Unique TensorFlower 2019-10-30 00:56:13 -07:00 committed by TensorFlower Gardener
parent e9a3aa158a
commit 552d6a22f6
30 changed files with 5084 additions and 1195 deletions

File diff suppressed because it is too large.

File diff suppressed because it is too large.

@ -104,6 +104,7 @@ cc_library(
"transforms/legalize_tf.cc",
],
deps = [
":convert_op_folder",
":hlo",
"//tensorflow/compiler/mlir/tensorflow",
"//tensorflow/compiler/mlir/tensorflow:lower_tf_lib",
@ -257,6 +258,7 @@ cc_library(
],
includes = ["include"],
deps = [
":convert_op_folder",
":hlo_ops_base_inc_gen",
":hlo_ops_inc_gen",
":xla_canonicalize_inc_gen",
@ -470,3 +472,12 @@ genrule(
" -o $@"),
tools = [":operator_writer_gen"],
)
cc_library(
name = "convert_op_folder",
srcs = ["convert_op_folder.cc"],
hdrs = ["convert_op_folder.h"],
deps = [
"@local_config_mlir//:IR",
],
)

@ -0,0 +1,84 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// This file defines helpers useful when creating or manipulating lhlo/hlo.
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
namespace xla {
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type) {
auto old_type = getElementTypeOrSelf(elements);
size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();
if (old_type.isa<mlir::FloatType>()) {
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = mlir::APInt(const llvm::APFloat&);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Float -> Float
return elements.mapValues(
new_type, llvm::function_ref<func_type>(
[&newFloatType](const llvm::APFloat& floatVal) {
llvm::APFloat newDouble(
mlir::FloatAttr::getValueAsDouble(floatVal));
bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven,
&loses_info);
return newDouble.bitcastToAPInt();
}));
}
// Float -> Int
return elements.mapValues(
new_type, llvm::function_ref<func_type>(
[&bit_width](const llvm::APFloat& floatVal) {
return llvm::APInt(
bit_width,
mlir::FloatAttr::getValueAsDouble(floatVal));
}));
}
// old_type is Integer
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = llvm::APInt(const llvm::APInt&);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Int -> Float
return elements.mapValues(
new_type, llvm::function_ref<func_type>([&newFloatType](
const llvm::APInt& intVal) {
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &loses_info);
return newDouble.bitcastToAPInt();
}));
}
// new_type is Integer
// Int -> Int
return elements.mapValues(
new_type,
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
return llvm::APInt(bit_width, intVal.getSExtValue());
}));
}
} // namespace xla

@ -0,0 +1,31 @@
/* Copyright 2019 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#define TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
namespace xla {
// Converts the given elements attr to the specified elements type.
// Requires type of the elements and new_type to be either integer or float
// type.
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_CONVERT_OP_FOLDER_H_

@ -19,6 +19,7 @@ limitations under the License.
#include "mlir/IR/Attributes.h" // TF:local_config_mlir
#include "mlir/IR/StandardTypes.h" // TF:local_config_mlir
#include "mlir/IR/TypeUtilities.h" // TF:local_config_mlir
#include "tensorflow/compiler/xla/literal.h"
namespace xla {
@ -79,4 +80,62 @@ mlir::DenseIntElementsAttr CreateDenseIntElementsAttrFromVector(
.cast<mlir::DenseIntElementsAttr>();
}
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type) {
auto old_type = getElementTypeOrSelf(elements);
size_t bit_width = new_type.isBF16() ? 64 : new_type.getIntOrFloatBitWidth();
if (old_type.isa<mlir::FloatType>()) {
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = mlir::APInt(const llvm::APFloat&);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Float -> Float
return elements.mapValues(
new_type, llvm::function_ref<func_type>(
[&newFloatType](const llvm::APFloat& floatVal) {
llvm::APFloat newDouble(
mlir::FloatAttr::getValueAsDouble(floatVal));
bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven,
&loses_info);
return newDouble.bitcastToAPInt();
}));
}
// Float -> Int
return elements.mapValues(
new_type, llvm::function_ref<func_type>(
[&bit_width](const llvm::APFloat& floatVal) {
return llvm::APInt(
bit_width,
mlir::FloatAttr::getValueAsDouble(floatVal));
}));
}
// old_type is Integer
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = llvm::APInt(const llvm::APInt&);
if (auto newFloatType = new_type.dyn_cast<mlir::FloatType>()) {
// Int -> Float
return elements.mapValues(
new_type, llvm::function_ref<func_type>([&newFloatType](
const llvm::APInt& intVal) {
llvm::APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
bool loses_info = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &loses_info);
return newDouble.bitcastToAPInt();
}));
}
// new_type is Integer
// Int -> Int
return elements.mapValues(
new_type,
llvm::function_ref<func_type>([&bit_width](const llvm::APInt& intVal) {
return llvm::APInt(bit_width, intVal.getSExtValue());
}));
}
} // namespace xla

@ -77,6 +77,11 @@ StatusOr<mlir::Type> ConvertShapeToType(const Shape& shape,
return ConvertTensorShapeToType<TypeT>(shape, builder);
}
// Converts the given elements attr to the specified elements type.
// Requires type of the elements and new_type to be either integer or float
// type.
mlir::ElementsAttr ConvertElementsAttr(const mlir::ElementsAttr& elements,
mlir::Type new_type);
} // namespace xla
#endif // TENSORFLOW_COMPILER_MLIR_XLA_HLO_UTILS_H_

@ -44,6 +44,7 @@ limitations under the License.
#include "mlir/IR/Types.h" // TF:local_config_mlir
#include "mlir/IR/Value.h" // TF:local_config_mlir
#include "mlir/Support/LogicalResult.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h.inc"
namespace mlir {
@ -246,69 +247,13 @@ void ConvertOp::build(Builder* builder, OperationState& result, Value* operand,
build(builder, result, result_ty, operand);
}
namespace {
// Converts the values of an ElementsAttr into the corresponding type.
ElementsAttr ConvertElements(const ElementsAttr& elements, Type newType) {
auto oldType = getElementTypeOrSelf(elements);
size_t bitWidth = newType.isBF16() ? 64 : newType.getIntOrFloatBitWidth();
if (oldType.isa<FloatType>()) {
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = APInt(const APFloat&);
if (auto newFloatType = newType.dyn_cast<FloatType>()) {
// Float -> Float
return elements.mapValues(
newType, llvm::function_ref<func_type>([&newFloatType](
const APFloat& floatVal) {
APFloat newDouble(FloatAttr::getValueAsDouble(floatVal));
bool losesInfo = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return newDouble.bitcastToAPInt();
}));
}
// Float -> Int
return elements.mapValues(
newType,
llvm::function_ref<func_type>([&bitWidth](const APFloat& floatVal) {
return APInt(bitWidth, FloatAttr::getValueAsDouble(floatVal));
}));
}
// oldType is Integer
// mapValues always takes a function returning APInt, even when the output
// is actually float.
using func_type = APInt(const APInt&);
if (auto newFloatType = newType.dyn_cast<FloatType>()) {
// Int -> Float
return elements.mapValues(
newType,
llvm::function_ref<func_type>([&newFloatType](const APInt& intVal) {
APFloat newDouble(static_cast<double>(intVal.getSExtValue()));
bool losesInfo = false;
newDouble.convert(newFloatType.getFloatSemantics(),
llvm::APFloat::rmNearestTiesToEven, &losesInfo);
return newDouble.bitcastToAPInt();
}));
}
// newType is Integer
// Int -> Int
return elements.mapValues(
newType, llvm::function_ref<func_type>([&bitWidth](const APInt& intVal) {
return APInt(bitWidth, intVal.getSExtValue());
}));
}
} // namespace
OpFoldResult ConvertOp::fold(ArrayRef<Attribute> operands) {
if (getOperand()->getType() == getResult()->getType()) return getOperand();
// If the operand is constant, we can do the conversion now.
if (auto elementsAttr = operands.front().dyn_cast_or_null<ElementsAttr>()) {
return ConvertElements(elementsAttr, getElementTypeOrSelf(getResult()));
return xla::ConvertElementsAttr(elementsAttr,
getElementTypeOrSelf(getResult()));
}
return {};

@ -490,6 +490,20 @@ def HLO_ConcatenateOp : HLO_Op<"concatenate",
let hasCustomHLOConverter = 1;
}
def HLO_CrossReplicaSumOp : HLO_Op<"cross-replica-sum",
[NoSideEffect, SameOperandsAndResultType]>, BASE_HLO_CrossReplicaSumOp {
let arguments = (ins
HLO_Tensor:$operand,
I64ElementsAttr:$replica_groups
);
let results = (outs HLO_Tensor);
// TODO(b/129422361) CrossReplicaSumOp has special conversion logic to HLO.
let hasCustomHLOConverter = 1;
}
// TODO(hinsu): Make this struct dialect independent so that it can be shared
// between HLO and LHLO dialect.
def ConvDimensionNumbers : StructAttr<"ConvDimensionNumbers", HLO_Dialect, [

@ -547,6 +547,22 @@ class BASE_HLO_ConcatenateOp {
}];
}
class BASE_HLO_CrossReplicaSumOp {
string summary = "Sums input across replicated instances.";
string description = [{
For each of the replica groups, operands of the group devices are summed
so that each device has the sum.
For example, suppose there are 8 TPU devices: `[A, B, C, D, E, F, G, H]`.
Passing group_assignment=`[[0,2,4,6],[1,3,5,7]]` sets `A, C, E, G` as group 0,
and `B, D, F, H` as group 1. Thus we get the outputs:
`[A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H, A+C+E+G, B+D+F+H]`.
See https://www.tensorflow.org/xla/operation_semantics#crossreplicasum.
}];
}
class BASE_HLO_ConvOp {
string summary = "Convolution operator";

@ -362,6 +362,10 @@ LogicalResult ExportXlaOp(CopyOp op, OpLoweringContext ctx) {
return failure();
}
LogicalResult ExportXlaOp(CrossReplicaSumOp op, OpLoweringContext ctx) {
return failure();
}
LogicalResult ExportXlaOp(DynamicSliceOp op, OpLoweringContext ctx) {
return failure();
}

@ -1670,3 +1670,15 @@ func @conv2d_backprop_filter(
} : (tensor<100x28x28x1xf32>, tensor<4xi32>, tensor<100x26x26x32xf32>) -> tensor<100x28x28x1xf32>
return %result : tensor<100x28x28x1xf32>
}
// CHECK-LABEL: @cross_replica_sum
func @cross_replica_sum(%input: tensor<10xf32>) -> tensor<10xf32> {
%replica_groups = "tf.Const" () {
value = dense<[[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi32>
} : () -> tensor<2x4xi32>
// CHECK: xla_hlo.cross-replica-sum
// CHECK-SAME: replica_groups = dense<{{\[}}[0, 2, 4, 6], [1, 3, 5, 7]]> : tensor<2x4xi64>
%result = "tf.CrossReplicaSum" (%input, %replica_groups) : (tensor<10xf32>, tensor<2x4xi32>) -> tensor<10xf32>
return %result : tensor<10xf32>
}

@ -38,6 +38,7 @@ limitations under the License.
#include "mlir/Transforms/DialectConversion.h" // TF:local_config_mlir
#include "tensorflow/compiler/mlir/tensorflow/ir/tf_ops.h"
#include "tensorflow/compiler/mlir/tensorflow/transforms/lower_tf.h"
#include "tensorflow/compiler/mlir/xla/convert_op_folder.h"
#include "tensorflow/compiler/mlir/xla/ir/hlo_ops.h"
#include "tensorflow/compiler/mlir/xla/transforms/passes.h"
#include "tensorflow/core/framework/common_shape_fns.h"

@ -228,6 +228,18 @@ def : Pat<(TF_ConcatV2Op $inputs, (TF_ConstOp OneElementAttr:$axis), $unused),
(HLO_ConcatenateOp $inputs, (GetHLOAxisFromTFAxis $axis, $inputs)),
[(HasRankedFirstOperand $inputs)]>;
//===----------------------------------------------------------------------===//
// CrossReplicaSum op patterns.
//===----------------------------------------------------------------------===//
def CastElementsToI64Elements : NativeCodeCall<
"::xla::ConvertElementsAttr("
"$0, $_builder.getIntegerType(64)).cast<DenseIntElementsAttr>()">;
def : Pat<(TF_CrossReplicaSumOp $input, (TF_ConstOp $group_assignment)),
(HLO_CrossReplicaSumOp $input,
(CastElementsToI64Elements $group_assignment))>;
//===----------------------------------------------------------------------===//
// Fft op patterns.
//===----------------------------------------------------------------------===//

@ -145,7 +145,12 @@ StatusOr<bool> CallInliner::Run(HloModule* module) {
call_graph->VisitNodes([&](const CallGraphNode& node) -> Status {
for (const CallSite& callsite : node.caller_callsites()) {
VLOG(1) << "Visiting callsite: " << callsite.ToString();
if (callsite.instruction()->opcode() == HloOpcode::kCall) {
bool callsite_alive =
absl::c_any_of(node.callers(), [&](HloComputation* caller) {
return caller->ContainsInstruction(callsite.instruction());
});
if (callsite.instruction()->opcode() == HloOpcode::kCall &&
callsite_alive) {
HloInstruction* call = callsite.instruction();
TF_RETURN_IF_ERROR(Inline(call).status());
did_mutate = true;

@ -142,6 +142,46 @@ TEST_F(CallInlinerTest, InlineWithoutRunningPass) {
ElementsAre(op::Constant()));
}
// Test that inlining can work with computations with a dead parameter.
TEST_F(CallInlinerTest, InlineWithEmptyComputation) {
const Shape pred = ShapeUtil::MakeShape(PRED, {});
auto module = CreateNewVerifiedModule();
Shape r0s32 = ShapeUtil::MakeShape(S32, {});
HloComputation::Builder empty(TestName() + ".empty");
empty.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "A"));
empty.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
HloComputation* empty_computation =
module->AddEmbeddedComputation(empty.Build());
HloComputation::Builder empty2(TestName() + ".empty");
empty2.AddInstruction(HloInstruction::CreateParameter(0, r0s32, "A"));
empty2.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
HloComputation* empty2_computation =
module->AddEmbeddedComputation(empty2.Build());
HloComputation::Builder entry("entry");
auto zero = entry.AddInstruction(
HloInstruction::CreateConstant(LiteralUtil::CreateR0<int32>(0)));
// The order of the call chain is crafted to test a specific pattern such
// that the third call instruction will be flattened before the second one
// (which makes the second call instruction dead before it is flattened).
entry.AddInstruction(
HloInstruction::CreateCall(r0s32, {zero}, empty_computation));
HloInstruction* call1 = entry.AddInstruction(
HloInstruction::CreateCall(r0s32, {zero}, empty2_computation));
entry.AddInstruction(
HloInstruction::CreateCall(r0s32, {call1}, empty_computation));
auto computation = module->AddEntryComputation(entry.Build());
CallInliner call_inliner;
TF_ASSERT_OK_AND_ASSIGN(bool mutated, call_inliner.Run(module.get()));
ASSERT_TRUE(mutated);
EXPECT_THAT(computation->root_instruction(), op::Constant());
}
TEST_F(CallInlinerTest, CallToOutfeedComputationIsInlined) {
const Shape f32 = ShapeUtil::MakeShape(F32, {});
auto module = CreateNewVerifiedModule();

@ -20,6 +20,7 @@ limitations under the License.
#include "tensorflow/lite/core/api/tensor_utils.h"
#include "tensorflow/lite/experimental/micro/memory_helpers.h"
#include "tensorflow/lite/experimental/micro/memory_planner/greedy_memory_planner.h"
#include "tensorflow/lite/experimental/micro/simple_memory_allocator.h"
namespace tflite {
@ -89,15 +90,29 @@ TfLiteStatus MicroAllocator::RegisterPreallocatedInput(uint8_t* buffer,
TfLiteStatus MicroAllocator::AllocateTensors() {
const size_t tensors_size = tensors_->size();
// It would be better not to allocate this memory for the lifetime of the
// model, but we don't have a straightforward way to avoid it.
TensorInfo* tensor_info =
reinterpret_cast<TensorInfo*>(memory_allocator_.AllocateFromTail(
sizeof(TensorInfo) * tensors_size, sizeof(TensorInfo)));
const flatbuffers::Vector<flatbuffers::Offset<Buffer>>* buffers =
model_->buffers();
// Initialize runtime tensors.
for (size_t i = 0; i < tensors_size; ++i) {
auto* runtime_tensor = &context_->tensors[i];
auto* flatbuffer_tensor = tensors_->Get(i);
// Preallocated inputs have already been set up earlier, so skip them.
const bool is_preallocated_input = (runtime_tensor->data.raw != nullptr);
if (!is_preallocated_input) {
TF_LITE_ENSURE_STATUS(InitializeRuntimeTensor(*flatbuffer_tensor, buffers,
error_reporter_,
runtime_tensor, nullptr));
}
}
// tensor_info is only used in this function.
auto tmp_allocator = memory_allocator_.CreateChildAllocator();
TensorInfo* tensor_info =
reinterpret_cast<TensorInfo*>(tmp_allocator.AllocateFromTail(
sizeof(TensorInfo) * tensors_size, sizeof(TensorInfo)));
// Set up the runtime data structures for all tensors.
for (size_t i = 0; i < tensors_size; ++i) {
TensorInfo* current = &tensor_info[i];
@ -112,14 +127,6 @@ TfLiteStatus MicroAllocator::AllocateTensors() {
current->last_used = -1;
}
current->needs_allocating = false;
// Preallocated inputs have already been set up earlier, so skip them.
const bool is_preallocated_input =
(current->runtime_tensor->data.raw != nullptr);
if (!is_preallocated_input) {
TF_LITE_ENSURE_STATUS(InitializeRuntimeTensor(
*current->flatbuffer_tensor, buffers, error_reporter_,
current->runtime_tensor, nullptr));
}
}
// First go through the inputs and figure out if they need to be allocated.
@ -181,8 +188,9 @@ TfLiteStatus MicroAllocator::AllocateTensors() {
uint8_t* aligned_arena = AlignPointerUp(arena_, kBufferAlignment);
const size_t alignment_loss = (aligned_arena - arena_);
// Remaining arena size that memory planner can use for calculating offsets.
int remaining_arena_size =
arena_size_ - (memory_allocator_.GetDataSize() + alignment_loss);
arena_size_ - (tmp_allocator.GetDataSize() + alignment_loss);
GreedyMemoryPlanner planner(aligned_arena, remaining_arena_size);
// Add the tensors to our allocation plan.
@ -201,8 +209,12 @@ TfLiteStatus MicroAllocator::AllocateTensors() {
}
}
// Actual size available for placing tensors. This includes memory held by the
// tensor info array, which will be released.
int actual_available_arena_size =
arena_size_ - (memory_allocator_.GetDataSize() + alignment_loss);
// Make sure we have enough room.
if (planner.GetMaximumMemorySize() > remaining_arena_size) {
if (planner.GetMaximumMemorySize() > actual_available_arena_size) {
error_reporter_->Report(
"Arena size is too small for activation buffers. Needed %d but only %d "
"was available.",

@ -22,6 +22,10 @@ namespace tflite {
uint8_t* SimpleMemoryAllocator::AllocateFromTail(size_t size,
size_t alignment) {
if (has_child_allocator_) {
// TODO(wangtz): Add error reporting when the parent allocator is locked!
return nullptr;
}
uint8_t* previous_free = (data_ + data_size_max_) - data_size_;
uint8_t* current_data = previous_free - size;
uint8_t* aligned_result = AlignPointerDown(current_data, alignment);
@ -34,4 +38,21 @@ uint8_t* SimpleMemoryAllocator::AllocateFromTail(size_t size,
return aligned_result;
}
SimpleMemoryAllocator SimpleMemoryAllocator::CreateChildAllocator() {
// Note that the parameterized constructor would initialize data_size_ to 0,
// which is not what we want here: the copy keeps the parent's data_size_.
SimpleMemoryAllocator child = *this;
child.parent_allocator_ = this;
// With C++ copy elision, &child should be available after return.
has_child_allocator_ = true;
return child;
}
SimpleMemoryAllocator::~SimpleMemoryAllocator() {
// Root allocator doesn't have a parent.
if (nullptr != parent_allocator_) {
parent_allocator_->has_child_allocator_ = false;
}
}
} // namespace tflite

@ -28,7 +28,7 @@ namespace tflite {
class SimpleMemoryAllocator {
public:
SimpleMemoryAllocator(uint8_t* buffer, size_t buffer_size)
: data_size_(0), data_size_max_(buffer_size), data_(buffer) {}
: data_size_max_(buffer_size), data_(buffer) {}
// Allocates memory starting at the end of the arena (highest address and
// moving downwards, so that tensor buffers can be allocated from the start
@ -37,10 +37,25 @@ class SimpleMemoryAllocator {
int GetDataSize() const { return data_size_; }
// A child allocator acts as a temporary allocator. Memory allocated by the
// child allocator is freed once the child allocator is deallocated. Child
// allocators can be cascaded to create, for example, a grandchild allocator,
// but at any given time only the most recently created child allocator can be
// used. All of its ancestors are locked to avoid memory corruption; a locked
// allocator cannot allocate memory.
// WARNING: Parent allocator needs to live longer than the child allocator.
SimpleMemoryAllocator CreateChildAllocator();
// Unlocks parent allocator when the child allocator is deconstructed.
~SimpleMemoryAllocator();
private:
int data_size_;
int data_size_ = 0;
size_t data_size_max_;
uint8_t* data_;
SimpleMemoryAllocator* parent_allocator_ = nullptr;
// The allocator is locked if it has a child.
bool has_child_allocator_ = false;
};
} // namespace tflite

@ -56,4 +56,32 @@ TF_LITE_MICRO_TEST(TestMultipleTooLarge) {
TF_LITE_MICRO_EXPECT_EQ(nullptr, result);
}
TF_LITE_MICRO_TEST(TestChildAllocator) {
constexpr size_t arena_size = 1024;
uint8_t arena[arena_size];
tflite::SimpleMemoryAllocator allocator(arena, arena_size);
uint8_t* first = allocator.AllocateFromTail(16, 4);
TF_LITE_MICRO_EXPECT_NE(nullptr, first);
{
auto child_allocator = allocator.CreateChildAllocator();
uint8_t* second = child_allocator.AllocateFromTail(16, 4);
TF_LITE_MICRO_EXPECT_EQ(second, first - 16);
auto grand_child_allocator = child_allocator.CreateChildAllocator();
uint8_t* third = grand_child_allocator.AllocateFromTail(15, 4);
TF_LITE_MICRO_EXPECT_EQ(third, second - 16);
// Parent allocator is locked.
TF_LITE_MICRO_EXPECT_EQ(nullptr, allocator.AllocateFromTail(16, 4));
TF_LITE_MICRO_EXPECT_EQ(nullptr, child_allocator.AllocateFromTail(16, 4));
}
// Parent allocator is unlocked.
auto child_allocator = allocator.CreateChildAllocator();
uint8_t* fourth = child_allocator.AllocateFromTail(16, 4);
TF_LITE_MICRO_EXPECT_EQ(fourth, first - 16);
}
TF_LITE_MICRO_TESTS_END

@ -212,7 +212,7 @@ void EvalAdd(TfLiteContext* context, TfLiteNode* node, TfLiteAddParams* params,
}
} else {
if (need_broadcast) {
TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, float);
TF_LITE_ADD(optimized_ops, BroadcastAddDispatch, float);
} else {
TF_LITE_ADD(optimized_ops, Add, float);
}
@ -256,11 +256,8 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ADD(reference_integer_ops, Add, int8_t);
}
} else {
if (op_params.broadcast_category ==
BroadcastableOpCategory::kGenericBroadcast) {
TF_LITE_ADD(reference_integer_ops, BroadcastAdd4DSlow, int8_t);
} else if (need_broadcast) {
TF_LITE_ADD(optimized_integer_ops, BroadcastAddFivefold, int8_t);
if (need_broadcast) {
TF_LITE_ADD(optimized_integer_ops, BroadcastAddDispatch, int8_t);
} else {
TF_LITE_ADD(optimized_integer_ops, Add, int8_t);
}
@ -273,11 +270,8 @@ TfLiteStatus EvalAddQuantized(TfLiteContext* context, TfLiteNode* node,
TF_LITE_ADD(reference_ops, Add, uint8_t);
}
} else {
if (op_params.broadcast_category ==
BroadcastableOpCategory::kGenericBroadcast) {
TF_LITE_ADD(optimized_ops, BroadcastAdd4DSlow, uint8_t);
} else if (need_broadcast) {
TF_LITE_ADD(optimized_ops, BroadcastAddFivefold, uint8_t);
if (need_broadcast) {
TF_LITE_ADD(optimized_ops, BroadcastAddDispatch, uint8_t);
} else {
TF_LITE_ADD(optimized_ops, Add, uint8_t);
}

@ -139,6 +139,68 @@ TEST(FloatAddOpModel, WithBroadcast) {
}
}
TEST(FloatAddOpModel, WithBroadcastGeneric) {
std::vector<int> test_shape1 = {1, 3, 1};
std::vector<int> test_shape2 = {2, 1, 2};
FloatAddOpModel m({TensorType_FLOAT32, test_shape1},
{TensorType_FLOAT32, test_shape2}, {TensorType_FLOAT32, {}},
ActivationFunctionType_NONE);
m.PopulateTensor<float>(m.input1(), {0.1, 0.2, 0.3});
m.PopulateTensor<float>(m.input2(), {0.1, 0.2, 0.3, 0.4});
m.Invoke();
EXPECT_THAT(m.GetOutput(),
ElementsAreArray(ArrayFloatNear({0.2, 0.3, 0.3, 0.4, 0.4, 0.5,
0.4, 0.5, 0.5, 0.6, 0.6, 0.7})));
}
TEST(FloatAddOpModel, MixedBroadcast) {
const std::vector<int> base_shape = {2, 3, 1, 2};
std::vector<std::vector<int>> test_shapes = {
{1, 1, 3, 2}, {1, 3, 1, 2}, {2, 1, 3, 1}, {2, 3, 1, 1}};
std::vector<std::vector<float>> test_outputs = {
{-0.1f, 2.6f, -0.7f, 2.8f, 0.7f, 3.2f, 1.1f, 0.8f, 0.5f,
1.0f, 1.9f, 1.4f, 1.0f, -0.8f, 0.4f, -0.6f, 1.8f, -0.2f,
1.4f, 3.1f, 0.8f, 3.3f, 2.2f, 3.7f, -1.4f, 0.3f, -2.0f,
0.5f, -0.6f, 0.9f, 0.9f, -1.9f, 0.3f, -1.7f, 1.7f, -1.3f},
{-0.1f, 2.6f, 0.5f, 1.0f, 1.8f, -0.2f, 1.4f, 3.1f, -2.0f, 0.5f, 1.7f,
-1.3f},
{-0.1f, 2.5f, 0.0f, 2.6f, -0.7f, 1.9f, 1.1f, 0.7f, 1.2f,
0.8f, 0.5f, 0.1f, 1.0f, -0.9f, 1.1f, -0.8f, 0.4f, -1.5f,
1.7f, 3.3f, 2.2f, 3.8f, 2.1f, 3.7f, -1.1f, 0.5f, -0.6f,
1.0f, -0.7f, 0.9f, 1.2f, -1.7f, 1.7f, -1.2f, 1.6f, -1.3f},
{-0.1f, 2.5f, 1.2f, 0.8f, 0.4f, -1.5f, 1.7f, 3.3f, -0.6f, 1.0f, 1.6f,
-1.3f}};
for (size_t i = 0; i < test_shapes.size(); ++i) {
FloatAddOpModel model_fixture(
{TensorType_FLOAT32, base_shape}, {TensorType_FLOAT32, test_shapes[i]},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
model_fixture.PopulateTensor<float>(
model_fixture.input1(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f,
2.8f, -1.6f, 0.0f, 0.7f, -2.2f});
model_fixture.PopulateTensor<float>(model_fixture.input2(),
{0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f});
model_fixture.Invoke();
EXPECT_THAT(model_fixture.GetOutput(),
ElementsAreArray(ArrayFloatNear(test_outputs[i], 0.0001f)))
<< "With shape number " << i;
}
// Re-run with exchanged inputs.
for (size_t i = 0; i < test_shapes.size(); ++i) {
FloatAddOpModel model_fixture(
{TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, base_shape},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
model_fixture.PopulateTensor<float>(model_fixture.input1(),
{0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f});
model_fixture.PopulateTensor<float>(
model_fixture.input2(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f,
2.8f, -1.6f, 0.0f, 0.7f, -2.2f});
model_fixture.Invoke();
EXPECT_THAT(model_fixture.GetOutput(),
ElementsAreArray(ArrayFloatNear(test_outputs[i], 0.0001f)))
<< "With shape number " << i;
}
}
TEST(IntegerAddOpModel, NoActivation) {
IntegerAddOpModel m({TensorType_INT32, {1, 2, 2, 1}},
{TensorType_INT32, {1, 2, 2, 1}}, {TensorType_INT32, {}},
@ -435,5 +497,31 @@ TEST(QuantizedAddOpModel, QuantizedWithMixedBroadcastInt8) {
QuantizedWithMixedBroadcast<TensorType_INT8, int8_t>();
}
template <enum TensorType tensor_type, typename integer_dtype>
void QuantizedWithGenericBroadcast() {
float kQuantizedTolerance = GetTolerance(-1.0, 1.0);
std::vector<int> test_shape1 = {1, 3, 1};
std::vector<int> test_shape2 = {2, 1, 2};
QuantizedAddOpModel m({tensor_type, test_shape1, -1.0, 1.0},
{tensor_type, test_shape2, -1.0, 1.0},
{tensor_type, {}, -1.0, 1.0},
ActivationFunctionType_NONE);
m.QuantizeAndPopulate<integer_dtype>(m.input1(), {0.1, 0.2, 0.3});
m.QuantizeAndPopulate<integer_dtype>(m.input2(), {0.1, -0.2, 0.3, -0.4});
m.Invoke();
EXPECT_THAT(m.GetDequantizedOutput<integer_dtype>(),
ElementsAreArray(ArrayFloatNear({0.2, -0.1, 0.3, 0., 0.4, 0.1,
0.4, -0.3, 0.5, -0.2, 0.6, -0.1},
kQuantizedTolerance)));
}
TEST(QuantizedAddOpModel, QuantizedWithGenericBroadcastUInt8) {
QuantizedWithGenericBroadcast<TensorType_UINT8, uint8_t>();
}
TEST(QuantizedAddOpModel, QuantizedWithGenericBroadcastInt8) {
QuantizedWithGenericBroadcast<TensorType_INT8, int8_t>();
}
} // namespace
} // namespace tflite

@ -18,6 +18,7 @@ limitations under the License.
#include "profiling/instrumentation.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/add.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@ -325,6 +326,23 @@ inline void BroadcastAddFivefold(const ArithmeticParams& unswitched_params,
}
}
inline void BroadcastAddDispatch(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int8* input1_data,
const RuntimeShape& input2_shape,
const int8* input2_data,
const RuntimeShape& output_shape,
int8* output_data) {
if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
return reference_integer_ops::BroadcastAdd4DSlow(
params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data);
}
BroadcastAddFivefold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
} // namespace optimized_integer_ops
} // namespace tflite

@ -18,6 +18,7 @@ limitations under the License.
#include "profiling/instrumentation.h"
#include "tensorflow/lite/kernels/internal/common.h"
#include "tensorflow/lite/kernels/internal/optimized/cpu_check.h"
#include "tensorflow/lite/kernels/internal/reference/integer_ops/mul.h"
#include "tensorflow/lite/kernels/internal/types.h"
namespace tflite {
@ -251,6 +252,23 @@ inline void BroadcastMulFivefold(const ArithmeticParams& unswitched_params,
}
}
inline void BroadcastMulDispatch(const ArithmeticParams& params,
const RuntimeShape& input1_shape,
const int8* input1_data,
const RuntimeShape& input2_shape,
const int8* input2_data,
const RuntimeShape& output_shape,
int8* output_data) {
if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
return reference_integer_ops::BroadcastMul4DSlow(
params, input1_shape, input1_data, input2_shape, input2_data,
output_shape, output_data);
}
BroadcastMulFivefold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
} // namespace optimized_integer_ops
} // namespace tflite

@ -2108,6 +2108,20 @@ inline void BroadcastAddFivefold(const ArithmeticParams& params,
}
}
template <typename T>
inline void BroadcastAddDispatch(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
return BroadcastAdd4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
BroadcastAddFivefold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
inline void MulElementwise(int size, const ArithmeticParams& params,
const float* input1_data, const float* input2_data,
float* output_data) {
@ -2601,6 +2615,20 @@ inline void BroadcastMulFivefold(const ArithmeticParams& params,
}
}
template <typename T>
inline void BroadcastMulDispatch(
const ArithmeticParams& params, const RuntimeShape& input1_shape,
const T* input1_data, const RuntimeShape& input2_shape,
const T* input2_data, const RuntimeShape& output_shape, T* output_data) {
if (params.broadcast_category == BroadcastableOpCategory::kGenericBroadcast) {
return BroadcastMul4DSlow(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
BroadcastMulFivefold(params, input1_shape, input1_data, input2_shape,
input2_data, output_shape, output_data);
}
// TODO(jiawen): We can implement BroadcastDiv on buffers of arbitrary
// dimensionality if the runtime code does a single loop over one dimension
// that handles broadcasting as the base case. The code generator would then

@ -146,7 +146,7 @@ void EvalMul(TfLiteContext* context, TfLiteNode* node, TfLiteMulParams* params,
}
} else {
if (need_broadcast) {
TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, float);
TF_LITE_MUL(optimized_ops, BroadcastMulDispatch, float);
} else {
TF_LITE_MUL(optimized_ops, Mul, float);
}
@ -186,7 +186,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
} else {
if (need_broadcast) {
TF_LITE_MUL(optimized_integer_ops, BroadcastMulFivefold, int8_t);
TF_LITE_MUL(optimized_integer_ops, BroadcastMulDispatch, int8_t);
} else {
TF_LITE_MUL(optimized_integer_ops, Mul, int8_t);
}
@ -201,7 +201,7 @@ TfLiteStatus EvalQuantized(TfLiteContext* context, TfLiteNode* node,
}
} else {
if (need_broadcast) {
TF_LITE_MUL(optimized_ops, BroadcastMulFivefold, uint8_t);
TF_LITE_MUL(optimized_ops, BroadcastMulDispatch, uint8_t);
} else {
TF_LITE_MUL(optimized_ops, Mul, uint8_t);
}

@ -159,6 +159,55 @@ TEST(FloatMulOpTest, WithBroadcast) {
}
}
TEST(FloatMulOpTest, MixedBroadcast) {
const std::vector<int> base_shape = {2, 3, 1, 2};
std::vector<std::vector<int>> test_shapes = {
{1, 1, 3, 2}, {1, 3, 1, 2}, {2, 1, 3, 1}, {2, 3, 1, 1}};
std::vector<std::vector<float>> test_outputs = {
{-0.06f, 0.69f, 0.12f, 1.15f, -0.30f, 2.07f, 0.18f, 0.15f, -0.36f,
0.25f, 0.90f, 0.45f, 0.16f, -0.33f, -0.32f, -0.55f, 0.80f, -0.99f,
0.24f, 0.84f, -0.48f, 1.40f, 1.20f, 2.52f, -0.32f, 0.00f, 0.64f,
0.00f, -1.60f, 0.00f, 0.14f, -0.66f, -0.28f, -1.10f, 0.70f, -1.98f},
{-0.06f, 0.69f, -0.36f, 0.25f, 0.80f, -0.99f, 0.24f, 0.84f, 0.64f, 0.00f,
0.70f, -1.98f},
{-0.06f, 0.46f, -0.09f, 0.69f, 0.12f, -0.92f, 0.18f, 0.10f, 0.27f,
0.15f, -0.36f, -0.20f, 0.16f, -0.22f, 0.24f, -0.33f, -0.32f, 0.44f,
0.60f, 1.40f, 1.20f, 2.80f, 1.08f, 2.52f, -0.80f, 0.00f, -1.60f,
0.00f, -1.44f, 0.00f, 0.35f, -1.10f, 0.70f, -2.20f, 0.63f, -1.98f},
{-0.06f, 0.46f, 0.27f, 0.15f, -0.32f, 0.44f, 0.60f, 1.40f, -1.60f, 0.00f,
0.63f, -1.98f}};
for (size_t i = 0; i < test_shapes.size(); ++i) {
FloatMulOpModel model_fixture(
{TensorType_FLOAT32, base_shape}, {TensorType_FLOAT32, test_shapes[i]},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
model_fixture.PopulateTensor<float>(
model_fixture.input1(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f,
2.8f, -1.6f, 0.0f, 0.7f, -2.2f});
model_fixture.PopulateTensor<float>(model_fixture.input2(),
{0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f});
model_fixture.Invoke();
EXPECT_THAT(model_fixture.GetOutput(),
ElementsAreArray(ArrayFloatNear(test_outputs[i], 0.0001f)))
<< "With shape number " << i;
}
// Re-run with exchanged inputs.
for (size_t i = 0; i < test_shapes.size(); ++i) {
FloatMulOpModel model_fixture(
{TensorType_FLOAT32, test_shapes[i]}, {TensorType_FLOAT32, base_shape},
{TensorType_FLOAT32, {}}, ActivationFunctionType_NONE);
model_fixture.PopulateTensor<float>(model_fixture.input1(),
{0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f});
model_fixture.PopulateTensor<float>(
model_fixture.input2(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f,
2.8f, -1.6f, 0.0f, 0.7f, -2.2f});
model_fixture.Invoke();
EXPECT_THAT(model_fixture.GetOutput(),
ElementsAreArray(ArrayFloatNear(test_outputs[i], 0.0001f)))
<< "With shape number " << i;
}
}
TEST(FloatMulOpTest, WithBroadcast2Elements) {
std::vector<std::vector<int>> test_shapes = {
{2, 2}, {2, 1, 2}, {1, 2, 2}, {1, 2, 1, 2}};
@ -342,6 +391,60 @@ void WithBroadcast() {
}
}
template <enum TensorType tensor_type, typename integer_dtype>
void QuantizedWithMixedBroadcast() {
float kQuantizedTolerance = GetTolerance(-3.f, 3.f);
const std::vector<int> base_shape = {2, 3, 1, 2};
std::vector<std::vector<int>> test_shapes = {
{1, 1, 3, 2}, {1, 3, 1, 2}, {2, 1, 3, 1}, {2, 3, 1, 1}};
std::vector<std::vector<float>> test_outputs = {
{-0.06f, 0.69f, 0.12f, 1.15f, -0.30f, 2.07f, 0.18f, 0.15f, -0.36f,
0.25f, 0.90f, 0.45f, 0.16f, -0.33f, -0.32f, -0.55f, 0.80f, -0.99f,
0.24f, 0.84f, -0.48f, 1.40f, 1.20f, 2.52f, -0.32f, 0.00f, 0.64f,
0.00f, -1.60f, 0.00f, 0.14f, -0.66f, -0.28f, -1.10f, 0.70f, -1.98f},
{-0.06f, 0.69f, -0.36f, 0.25f, 0.80f, -0.99f, 0.24f, 0.84f, 0.64f, 0.00f,
0.70f, -1.98f},
{-0.06f, 0.46f, -0.09f, 0.69f, 0.12f, -0.92f, 0.18f, 0.10f, 0.27f,
0.15f, -0.36f, -0.20f, 0.16f, -0.22f, 0.24f, -0.33f, -0.32f, 0.44f,
0.60f, 1.40f, 1.20f, 2.80f, 1.08f, 2.52f, -0.80f, 0.00f, -1.60f,
0.00f, -1.44f, 0.00f, 0.35f, -1.10f, 0.70f, -2.20f, 0.63f, -1.98f},
{-0.06f, 0.46f, 0.27f, 0.15f, -0.32f, 0.44f, 0.60f, 1.40f, -1.60f, 0.00f,
0.63f, -1.98f}};
for (size_t i = 0; i < test_shapes.size(); ++i) {
QuantizedMulOpModel model_fixture({tensor_type, base_shape, -3.f, 3.f},
{tensor_type, test_shapes[i], -3.f, 3.f},
{tensor_type, {}, -3.f, 3.f},
ActivationFunctionType_NONE);
model_fixture.QuantizeAndPopulate<integer_dtype>(
model_fixture.input1(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f,
2.8f, -1.6f, 0.0f, 0.7f, -2.2f});
model_fixture.QuantizeAndPopulate<integer_dtype>(
model_fixture.input2(), {0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f});
model_fixture.Invoke();
EXPECT_THAT(
model_fixture.GetDequantizedOutput<integer_dtype>(),
ElementsAreArray(ArrayFloatNear(test_outputs[i], kQuantizedTolerance)))
<< "With shape number " << i;
}
// Re-run with exchanged inputs.
for (size_t i = 0; i < test_shapes.size(); ++i) {
QuantizedMulOpModel model_fixture({tensor_type, test_shapes[i], -3.f, 3.f},
{tensor_type, base_shape, -3.f, 3.f},
{tensor_type, {}, -3.f, 3.f},
ActivationFunctionType_NONE);
model_fixture.QuantizeAndPopulate<integer_dtype>(
model_fixture.input1(), {0.2f, 0.3f, -0.4f, 0.5f, 1.0f, 0.9f});
model_fixture.QuantizeAndPopulate<integer_dtype>(
model_fixture.input2(), {-0.3f, 2.3f, 0.9f, 0.5f, 0.8f, -1.1f, 1.2f,
2.8f, -1.6f, 0.0f, 0.7f, -2.2f});
model_fixture.Invoke();
EXPECT_THAT(
model_fixture.GetDequantizedOutput<integer_dtype>(),
ElementsAreArray(ArrayFloatNear(test_outputs[i], kQuantizedTolerance)))
<< "With shape number " << i;
}
}
TEST(QuantizedMulOpTest, WithBroadcastUInt8) {
WithBroadcast<TensorType_UINT8, uint8_t>();
}
@ -350,5 +453,13 @@ TEST(QuantizedMulOpTest, WithBroadcastInt8) {
WithBroadcast<TensorType_INT8, int8_t>();
}
TEST(QuantizedMulOpTest, QuantizedWithMixedBroadcastUInt8) {
QuantizedWithMixedBroadcast<TensorType_UINT8, uint8_t>();
}
TEST(QuantizedMulOpTest, QuantizedWithMixedBroadcastInt8) {
QuantizedWithMixedBroadcast<TensorType_INT8, int8_t>();
}
} // namespace
} // namespace tflite

@ -107,7 +107,7 @@ def op_priority(op_type):
Integer value corresponding to the priority of the op.
"""
if op_type in ('Const', 'Shape', 'BroadcastGradientArgs', 'Range',
'VariableShape', 'Fill', 'OneHot'):
'VariableShape', 'Fill', 'OneHot', 'ShapeN'):
# Lowest priority ops, e.g., constant ops across different steps.
# They will be traced only if trace_level>=7.
return 7

@ -649,7 +649,15 @@ def outside_compilation(computation, *args, **kwargs):
# we need to attach _xla_outside_compilation attribute directly because we are
# not in TPUReplicateContext.
if isinstance(graph, func_graph.FuncGraph):
tpu_context, _ = _enclosing_tpu_context_and_graph()
try:
tpu_context, _ = _enclosing_tpu_context_and_graph()
except ValueError:
logging.warning(
"Outside compilation attempted outside TPUReplicateContext "
"scope. As no enclosing TPUReplicateContext can be found, "
"returning the result of `computation` as is.")
return computation(*args, **kwargs)
# pylint: disable=protected-access
outside_compilation_name = str(tpu_context._outside_compilation_counter)
tpu_context._outside_compilation_counter = (

@ -78,8 +78,9 @@ function update_bazel_linux {
popd
PATH="/home/kbuilder/bin:$PATH"
set_bazel_outdir
which bazel
bazel version
}
# LINT.ThenChange(
# //tensorflow_estimator/google/kokoro/common.sh)
@ -99,6 +100,9 @@ function update_bazel_macos {
run_with_retry "${BAZEL_COMMAND}"
# Add new bazel installation to path
PATH="/Users/kbuilder/bin:$PATH"
set_bazel_outdir
which bazel
bazel version
}
function install_pip2 {