Merged commit includes the following changes:

313738015 by A. Unique TensorFlower <gardener@tensorflow.org>:
    Bump open source llvm revision to b726d071b4aa46004228fc38ee5bfd167f999bfe

313737890 by A. Unique TensorFlower <gardener@tensorflow.org>:
    Automated rollback of changelist 313718130.

313733429 by A. Unique TensorFlower <gardener@tensorflow.org>:
    Automated rollback of changelist 313729562.

313729562 by A. Unique TensorFlower <gardener@tensorflow.org>:
    [TF:STATELESS_RNG] Clarify that the same output of a stateless RNG is only guaranteed for the same shape and seed.

313718732 by A. Unique TensorFlower <gardener@tensorflow.org>:
    [XLA:SPMD] Handle window reversal in backprop filter convolutions.

313718302 by A. Unique TensorFlower <gardener@tensorflow.org>:
    [Core ML Delegate] Add FP16 support for Convolution.

313718156 by A. Unique TensorFlower <gardener@tensorflow.org>:
    Integrate LLVM at https://github.com/llvm/llvm-project/commit/b726d071b4aa

PiperOrigin-RevId: 313738015
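As a side note on the stateless-RNG clarification in 313729562, the guarantee it documents can be checked with a short sketch (assumes a TensorFlow 2.x install; this snippet is illustrative and not part of the commit):

    import tensorflow as tf

    # Same shape and same seed: stateless RNGs guarantee identical output.
    a = tf.random.stateless_uniform(shape=[2, 3], seed=[1, 2])
    b = tf.random.stateless_uniform(shape=[2, 3], seed=[1, 2])
    assert bool(tf.reduce_all(a == b))

    # With a different shape, the same seed carries no guarantee of any
    # relationship to the values above -- that is exactly what the reworded
    # documentation spells out.
    c = tf.random.stateless_uniform(shape=[3, 3], seed=[1, 2])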
This commit is contained in:
  parent a1ae008076
  commit 4de4c60972

Changed paths:
  tensorflow/compiler/xla/service/spmd/spmd_partitioner.cc
  tensorflow/compiler/xla/service/spmd/spmd_partitioner.h
  tensorflow/compiler/xla/service/spmd/spmd_partitioner_test.cc
  tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.cc
  tensorflow/compiler/xla/service/spmd/spmd_partitioner_util.h
  tensorflow/lite/delegates/
  tensorflow/lite/experimental/delegates/coreml/
  tensorflow/python/keras/engine/
  tensorflow/workspace.bzl
  third_party/mlir/
@@ -308,7 +308,8 @@ PartitionedHlo PartitionedHlo::ReshardNoCache(const HloSharding& target) {
  return PartitionedHlo(slice, base_shape_, state_);
}

PartitionedHlo PartitionedHlo::PadWithValue(HloInstruction* pad_value) const {
PartitionedHlo PartitionedHlo::PadWithValue(
    HloInstruction* pad_value, absl::Span<const int64> left_padded_dims) const {
  const HloSharding& sharding = hlo_->sharding();
  const Shape& shape = hlo_->shape();
  CHECK(!shape.IsTuple() && shape.element_type() != TOKEN);
@@ -327,13 +328,20 @@ PartitionedHlo PartitionedHlo::PadWithValue(HloInstruction* pad_value) const {
    auto index_in_full_shape =
        state_.b->AddInstruction(HloInstruction::CreateBinary(
            index_shape, HloOpcode::kAdd, iota, broadcast_start_index));
    auto valid_size = state_.b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32>(base_shape_.dimensions(dim))));
    auto broadcast_valid_size = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, valid_size, {}));
    ComparisonDirection direction = ComparisonDirection::kLt;
    int64 index_limit = base_shape_.dimensions(dim);
    if (absl::c_linear_search(left_padded_dims, dim)) {
      direction = ComparisonDirection::kGe;
      index_limit =
          index_shape.dimensions(dim) * sharding.tile_assignment().dim(dim) -
          index_limit;
    }
    auto limit = state_.b->AddInstruction(HloInstruction::CreateConstant(
        LiteralUtil::CreateR0<int32>(index_limit)));
    auto broadcast_limit = state_.b->AddInstruction(
        HloInstruction::CreateBroadcast(index_shape, limit, {}));
    return state_.b->AddInstruction(HloInstruction::CreateCompare(
        mask_shape, index_in_full_shape, broadcast_valid_size,
        ComparisonDirection::kLt));
        mask_shape, index_in_full_shape, broadcast_limit, direction));
  };

  HloInstruction* mask = nullptr;
@@ -2328,39 +2336,14 @@ Status SpmdPartitioningVisitor::HandleReverse(HloInstruction* hlo) {
  auto operand = GetPartitionedHlo(reverse->operand(0))
                     .Reshard(hlo_sharding_util::ReverseSharding(
                         reverse->sharding(), reverse->dimensions()));
  // Create a window config to halo exchange for unevenly partitioned reverse
  // dimensions.
  Window window;
  for (int64 i = 0; i < hlo->shape().rank(); ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_window_dilation(1);
    dim->set_window_reversal(false);
    int64 low_padding = 0;
    if (absl::c_linear_search(reverse->dimensions(), i)) {
      low_padding =
          RoundUpToNearest(reverse->shape().dimensions(i),
                           reverse->sharding().tile_assignment().dim(i)) -
          reverse->shape().dimensions(i);
    }
    dim->set_padding_low(low_padding);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
  }

  auto reshard_operand = operand.ReshardAsWindowedInput(
      window, operand.sharding(),
      CreateZero(ShapeUtil::MakeShape(hlo->shape().element_type(), {}), &b_),
      /*mask_invalid_region=*/false);
  if (!reshard_operand.has_value()) {
  auto left_padded_operand =
      HaloExchangeToPadOnLeft(operand, reverse->dimensions());
  if (!left_padded_operand) {
    return DefaultAction(hlo);
  }
  TF_RET_CHECK(!reshard_operand->dynamic_slice_index_on_output.has_value());
  SetPartitionedHlo(hlo, [&] {
    return b_.AddInstruction(
        hlo->CloneWithNewOperands(reshard_operand->sharded_input->shape(),
                                  {reshard_operand->sharded_input}));
    return b_.AddInstruction(hlo->CloneWithNewOperands(
        left_padded_operand->shape(), {left_padded_operand}));
  });
  return Status::OK();
}
@@ -2772,10 +2755,31 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
  for (int64 i = 0; i < rhs_to_lhs_indices.size(); ++i) {
    lhs_to_rhs_indices[rhs_to_lhs_indices[i]] = i;
  }
  auto aligned_rhs_sharding =
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices);
  auto aligned_lhs_sharding =
      hlo_sharding_util::TransposeSharding(rhs.sharding(), lhs_to_rhs_indices);

  Window window = hlo->window();
  std::vector<int64> reversed_rhs_dims;
  for (int64 i = 0; i < window.dimensions_size(); ++i) {
    if (window.dimensions(i).window_reversal()) {
      reversed_rhs_dims.push_back(dnums.kernel_spatial_dimensions(i));
    }
  }
  if (!reversed_rhs_dims.empty()) {
    // Make the reversed dims left-padded to prepare for window reversal.
    auto left_padded_rhs = HaloExchangeToPadOnLeft(rhs, reversed_rhs_dims);
    if (left_padded_rhs == nullptr) {
      return DefaultAction(hlo);
    }
    left_padded_rhs->set_sharding(rhs.sharding());
    rhs = PartitionedHlo(left_padded_rhs, rhs.base_shape(), rhs.state());
  }
  // Consider window reversal when resharding RHS or LHS. Note: this will not
  // reverse the data in the shard. We use window reversal to do that.
  auto aligned_rhs_sharding = hlo_sharding_util::ReverseSharding(
      hlo_sharding_util::TransposeSharding(lhs.sharding(), rhs_to_lhs_indices),
      reversed_rhs_dims);
  auto aligned_lhs_sharding = hlo_sharding_util::TransposeSharding(
      hlo_sharding_util::ReverseSharding(rhs.sharding(), reversed_rhs_dims),
      lhs_to_rhs_indices);

  auto unsupported_sharding = [&](const HloSharding& lhs_sharding,
                                  const HloSharding& rhs_sharding) {
@@ -2792,13 +2796,14 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
      return DefaultAction(hlo);
    }
    lhs = lhs.Reshard(aligned_lhs_sharding).PadWithValue(zero);
    rhs = rhs.PadWithValue(zero);
    rhs = rhs.PadWithValue(zero, reversed_rhs_dims);
  } else {
    if (unsupported_sharding(lhs.sharding(), aligned_rhs_sharding)) {
      return DefaultAction(hlo);
    }
    lhs = lhs.PadWithValue(zero);
    rhs = rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero);
    rhs =
        rhs.Reshard(aligned_rhs_sharding).PadWithValue(zero, reversed_rhs_dims);
  }

  // Reshard LHS by exchanging halo such that each shard computes the partial
@@ -2817,8 +2822,6 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
  // = (LHS - RHS) * i + low_padding
  // * right-halo: limit(i) - (i + 1) * LHS
  // = [{(RHS - 1) * D + 1} - LHS] * (i + 1) + (WC - 1) * stride - low_padding

  Window window = hlo->window();
  std::vector<int64> shard_counts(dnums.input_spatial_dimensions_size());
  std::vector<int64> lhs_shard_sizes(dnums.input_spatial_dimensions_size());
  std::vector<int64> rhs_shard_sizes(dnums.input_spatial_dimensions_size());
@@ -2827,7 +2830,7 @@ Status SpmdPartitioningVisitor::HandleConvolutionTiledLhsAndRhs(
    int64 rhs_dimension = dnums.kernel_spatial_dimensions(i);
    int64 shard_count = lhs.sharding().tile_assignment().dim(lhs_dimension);
    auto wd = window.dimensions(i);
    if (wd.base_dilation() != 1 || wd.window_reversal()) {
    if (wd.base_dilation() != 1) {
      return DefaultAction(hlo);
    }

@@ -243,8 +243,13 @@ class PartitionedHlo {
  // the reshard cache.
  PartitionedHlo Reshard(const HloSharding& target);

  // Pads the garbage area of the output with the provided value.
  PartitionedHlo PadWithValue(HloInstruction* pad_value) const;
  // Pads the garbage area of the output with the provided value. Normally,
  // unevenly partitioned dimensions are padded on the right, but this function
  // allows specifying left-padded dimensions, which can be used during the
  // handling of kReverse, etc.
  PartitionedHlo PadWithValue(
      HloInstruction* pad_value,
      absl::Span<const int64> left_padded_dims = {}) const;
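For intuition only, a NumPy sketch of the padding convention the new left_padded_dims parameter selects (toy sizes; -1 stands in for pad_value; this is not the XLA implementation):

    import numpy as np

    data = np.arange(1, 6)                 # logical length 5, tiled across 2 shards
    shards, pad = 2, -1
    per_shard = -(-len(data) // shards)    # ceil(5 / 2) = 3

    # Default (right-padded) tiling: garbage sits at the end of the last shard.
    right = np.full(shards * per_shard, pad)
    right[:len(data)] = data
    print(right.reshape(shards, per_shard))   # [[ 1  2  3], [ 4  5 -1]]

    # Left-padded tiling: garbage sits at the front of the first shard, which
    # is the layout kReverse and window reversal want.
    left = np.full(shards * per_shard, pad)
    left[-len(data):] = data
    print(left.reshape(shards, per_shard))    # [[-1  1  2], [ 3  4  5]]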

  // Returns the SPMD instruction.
  HloInstruction* hlo() const { return hlo_; }
@@ -263,6 +268,8 @@ class PartitionedHlo {
      const Window& window, const HloSharding& target,
      HloInstruction* pad_value, bool mask_invalid_region = true);

  const PartitioningState& state() const { return state_; }

 private:
  // Same as Reshard except that it does not explicitly modify the reshard
  // cache, although it would indirectly modify by calling Replicate().

@@ -1300,6 +1300,35 @@ ENTRY entry {
                          op::Shape("f32[1,1,64,256]")));
}

TEST_F(SpmdPartitioningTest, ConvolutionLhsTiledRhsTiledWindowReversal) {
  const char* const hlo_string = R"(
HloModule module

ENTRY entry {
  %lhs = f32[5,128,64] parameter(0), sharding={devices=[2,1,1]0,1}
  %rhs = f32[5,128,256] parameter(1), sharding={devices=[2,1,1]1,0}
  ROOT %conv = f32[1,64,256] convolution(%lhs, %rhs),
    window={size=5 rhs_reversal=1}, dim_labels=0fb_0io->0bf,
    sharding={replicated}
})";

  TF_ASSERT_OK_AND_ASSIGN(auto module,
                          PartitionComputation(hlo_string, /*num_devices=*/2));
  VLOG(1) << module->ToString();

  auto lhs_masked =
      AllOf(op::Shape("f32[3,128,64]"), op::Select(_, op::Parameter(0), _));
  auto rhs_left_padded = op::Slice(op::Concatenate(
      op::CollectivePermute(op::Slice(op::Parameter(1))), op::Parameter(1)));
  auto rhs_masked =
      AllOf(op::Shape("f32[3,128,256]"), op::Select(_, rhs_left_padded, _));

  auto root = module->entry_computation()->root_instruction();
  EXPECT_THAT(root,
              AllOf(op::AllReduce(op::Convolution(lhs_masked, rhs_masked)),
                    op::Shape("f32[1,64,256]")));
}
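For intuition only, here is a NumPy sketch of what rhs_reversal=1 in the window config above means (unrelated to the XLA APIs; one spatial dimension with made-up values):

    import numpy as np

    x = np.arange(8.0)                # input, one spatial dimension
    k = np.array([1.0, 2.0, 3.0])     # kernel

    # XLA's convolution is correlation-style: the kernel is not flipped.
    plain = np.correlate(x, k, mode='valid')

    # rhs_reversal=1 reads the kernel window back-to-front, which is the same
    # as flipping the kernel data, i.e. a true mathematical convolution.
    reversed_window = np.correlate(x, k[::-1], mode='valid')
    assert np.allclose(reversed_window, np.convolve(x, k, mode='valid'))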

TEST_F(SpmdPartitioningTest, DotLhsTiledRhsTiledWithReshard) {
  const char* const hlo_string = R"(
HloModule module

@@ -664,5 +664,43 @@ absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
  return valid_slice;
}

HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64> dims) {
  if (original.sharding().IsTileMaximal()) {
    return original.hlo();
  }
  // Create a window config to halo exchange for unevenly partitioned reverse
  // dimensions.
  Window window;
  for (int64 i = 0; i < original.base_shape().rank(); ++i) {
    WindowDimension* dim = window.add_dimensions();
    dim->set_size(1);
    dim->set_stride(1);
    dim->set_window_dilation(1);
    dim->set_window_reversal(false);
    int64 low_padding = 0;
    if (absl::c_linear_search(dims, i)) {
      low_padding =
          RoundUpToNearest(original.base_shape().dimensions(i),
                           original.sharding().tile_assignment().dim(i)) -
          original.base_shape().dimensions(i);
    }
    dim->set_padding_low(low_padding);
    dim->set_padding_high(0);
    dim->set_base_dilation(1);
  }

  auto reshard_window = original.ReshardAsWindowedInput(
      window, original.sharding(),
      CreateZero(ShapeUtil::MakeShape(original.base_shape().element_type(), {}),
                 original.state().b),
      /*mask_invalid_region=*/false);
  if (!reshard_window.has_value()) {
    return nullptr;
  }
  CHECK(!reshard_window->dynamic_slice_index_on_output.has_value());
  return reshard_window->sharded_input;
}

}  // namespace spmd
}  // namespace xla

@@ -227,6 +227,13 @@ absl::optional<HloInstruction*> ExchangeHaloAndGetValidData(
    const SPMDCollectiveOpsCreator& collective_ops_creator,
    int64* next_channel_id, SpmdBuilder* b, bool mask_invalid_region = true);

// Uses halo exchange to change from right-padding to left-padding for uneven
// tiled sharding on the given dimensions. Tiled sharding always pads unevenly
// partitioned data on the right, but we need to swap it to the left for
// kReverse or kConvolution with window reversal.
HloInstruction* HaloExchangeToPadOnLeft(PartitionedHlo& original,
                                        absl::Span<const int64> dims);

}  // namespace spmd
}  // namespace xla

@@ -32,6 +32,7 @@ cc_library(
        "//tensorflow/lite:minimal_logging",
        "//tensorflow/lite:util",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/delegates:utils",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/nnapi:nnapi_implementation",
        "//tensorflow/lite/nnapi:nnapi_lib",
@@ -68,6 +69,7 @@ cc_library(
        "//tensorflow/lite:minimal_logging",
        "//tensorflow/lite:util",
        "//tensorflow/lite/c:common",
        "//tensorflow/lite/delegates:utils",
        "//tensorflow/lite/kernels:kernel_util",
        "//tensorflow/lite/nnapi:nnapi_implementation",
        "//tensorflow/lite/nnapi:nnapi_lib",

@@ -53,6 +53,7 @@ limitations under the License.
#include "tensorflow/lite/context_util.h"
#include "tensorflow/lite/delegates/nnapi/nnapi_delegate_kernel.h"
#include "tensorflow/lite/delegates/nnapi/quant_lstm_sup.h"
#include "tensorflow/lite/delegates/utils.h"
#include "tensorflow/lite/kernels/kernel_util.h"
#include "tensorflow/lite/minimal_logging.h"
#include "tensorflow/lite/nnapi/nnapi_implementation.h"
@@ -4178,17 +4179,6 @@ int StatefulNnApiDelegate::GetNnApiErrno() const {
using ::tflite::delegate::nnapi::kMinSdkVersionForNNAPI;
using ::tflite::delegate::nnapi::kMinSdkVersionForNNAPI12;

namespace {

std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
    const std::vector<int>& data) {
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
      TfLiteIntArrayCreate(data.size()));
  std::copy(data.begin(), data.end(), result->data);
  return result;
}
}  // namespace

// static
TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
    TfLiteContext* context, TfLiteDelegate* delegate, const NnApi* nnapi,
@@ -4198,7 +4188,8 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
  auto* delegate_data = static_cast<Data*>(delegate->data_);
  // The first entry in the array is the element count.

  auto supported_nodes_int_array = BuildTfLiteIntArray(supported_nodes);
  auto supported_nodes_int_array =
      delegates::BuildTfLiteIntArray(supported_nodes);
  TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
      context, supported_nodes_int_array.get(), params_array, num_partitions));
  // For each partition, check which nodes are actually supported by the
@@ -4231,7 +4222,7 @@ TfLiteStatus StatefulNnApiDelegate::GetNodesSupportedByAccelerator(
        // We changed the set of nodes to delegate; this will create a different
        // partitioning layout.
        auto device_sup_nodes_int_array =
            BuildTfLiteIntArray(*device_supported_nodes);
            delegates::BuildTfLiteIntArray(*device_supported_nodes);
        TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
            context, device_sup_nodes_int_array.get(), params_array,
            num_partitions));
@@ -4428,7 +4419,8 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
        &num_partitions, &params_array, nnapi_errno));
  } else {
    nodes_to_delegate = supported_nodes;
    auto supported_nodes_int_array = BuildTfLiteIntArray(supported_nodes);
    auto supported_nodes_int_array =
        delegates::BuildTfLiteIntArray(supported_nodes);
    TF_LITE_ENSURE_STATUS(context->PreviewDelegatePartitioning(
        context, supported_nodes_int_array.get(), &params_array,
        &num_partitions));
@@ -4445,7 +4437,8 @@ TfLiteStatus StatefulNnApiDelegate::DoPrepare(TfLiteContext* context,
  } else {
    // Request TFLite to partition the graph and make kernels
    // for each independent node subset a new nnapi_delegate_kernel.
    auto nodes_to_delegate_int_array = BuildTfLiteIntArray(nodes_to_delegate);
    auto nodes_to_delegate_int_array =
        delegates::BuildTfLiteIntArray(nodes_to_delegate);
    return context->ReplaceNodeSubsetsWithDelegateKernels(
        context, nnapi_delegate_kernel, nodes_to_delegate_int_array.get(),
        delegate);

@@ -46,6 +46,14 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
  return kTfLiteOk;
}

std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
    const std::vector<int>& data) {
  std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> result(
      TfLiteIntArrayCreate(data.size()));
  std::copy(data.begin(), data.end(), result->data);
  return result;
}

TfLiteStatus GraphPartitionHelper::Partition(
    std::set<std::string>* unsupported_nodes_info) {
  const auto prepare_status = PrepareSupportedNodes(unsupported_nodes_info);
@@ -148,12 +156,16 @@ TfLiteStatus FP16GraphPartitionHelper::Partition(
}

std::vector<int> FP16GraphPartitionHelper::GetNodesOfFirstNLargestPartitions(
    int n) {
    int n, int min_nodes_per_partition,
    std::vector<TfLiteDelegateParams*>* partitions) {
  // We first get partitions to reduce the number of nodes to be checked in
  // deciding which dequant ops could actually be replaced. And then we
  // remap input-tensor to dequant nodes' inputs and remove those
  // to-be-reserved dequant nodes.
  auto first_nps = GetFirstNLargestPartitions(n);
  auto first_nps = GetFirstNLargestPartitions(n, min_nodes_per_partition);
  if (partitions != nullptr) {
    *partitions = first_nps;
  }
  std::vector<int> ops_to_replace;
  for (const auto p : first_nps) {
    auto nodes = p->nodes_to_replace;

@@ -20,6 +20,7 @@ limitations under the License.

#include <functional>
#include <limits>
#include <memory>
#include <set>
#include <string>
#include <unordered_map>
@@ -40,6 +41,9 @@ TfLiteStatus CreateNewTensorWithDifferentType(TfLiteContext* context,
                                              TfLiteTensor** new_tensor,
                                              int* new_tensor_index);

std::unique_ptr<TfLiteIntArray, TfLiteIntArrayDeleter> BuildTfLiteIntArray(
    const std::vector<int>& data);

using IsNodeSupportedFn =
    std::function<bool(TfLiteContext*, TfLiteNode*, TfLiteRegistration*,
                       std::string* unsupported_details)>;
@@ -134,7 +138,9 @@ class FP16GraphPartitionHelper : public GraphPartitionHelper {
  // returned. The partition is ranked according to the number of nodes.
  // TODO(b/156707497): Add this to superclass besides
  // GetFirstNLargestPartitions (one that returns partitions instead of nodes)
  std::vector<int> GetNodesOfFirstNLargestPartitions(int n);
  std::vector<int> GetNodesOfFirstNLargestPartitions(
      int n, int min_nodes_per_partition = 0,
      std::vector<TfLiteDelegateParams*>* partitions = nullptr);
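A toy restatement of the selection this signature describes (pure Python over lists of node ids; the real helper works with TfLiteDelegateParams* partitions, so treat this only as a sketch of the ranking and the min_nodes_per_partition filter):

    def first_n_largest(partitions, n, min_nodes_per_partition=0):
        # Rank partitions by node count, keep the n largest, and drop any
        # partition smaller than the requested minimum size.
        ranked = sorted(partitions, key=len, reverse=True)
        return [p for p in ranked[:n] if len(p) >= min_nodes_per_partition]

    assert first_n_largest([[7], [2, 3, 4], [5, 6]], n=2,
                           min_nodes_per_partition=2) == [[2, 3, 4], [5, 6]]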

 protected:
  bool IsNodeSupported(TfLiteContext* context, TfLiteNode* node,

@@ -167,21 +167,42 @@ void ConvolutionOpBuilder::TransposeKernelWeights() {
    layer_->mutable_convolution()->set_isdeconvolution(true);
  }

  auto* coreml_weights =
      layer_->mutable_convolution()->mutable_weights()->mutable_floatvalue();
  coreml_weights->Resize(NumElements(weights_), 0);
  if (weights_->type == kTfLiteFloat32) {
    auto* coreml_weights =
        layer_->mutable_convolution()->mutable_weights()->mutable_floatvalue();
    coreml_weights->Resize(NumElements(weights_), 0);

    optimized_ops::Transpose<float>(params, tfl_shape, weights_->data.f,
                                    coreml_shape, coreml_weights->mutable_data());
    optimized_ops::Transpose<float>(params, tfl_shape, weights_->data.f,
                                    coreml_shape,
                                    coreml_weights->mutable_data());
  } else if (weights_->type == kTfLiteFloat16) {
    auto* coreml_weights = layer_->mutable_convolution()
                               ->mutable_weights()
                               ->mutable_float16value();
    // float16value has type of bytes (std::string)
    coreml_weights->resize(weights_->bytes, 0);

    optimized_ops::Transpose<uint16_t>(
        params, tfl_shape, reinterpret_cast<uint16_t*>(weights_->data.raw),
        coreml_shape, reinterpret_cast<uint16_t*>(coreml_weights->data()));
  }
}
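The uint16_t reinterpret_cast in the FP16 branch above works because transposing the raw 16-bit payload is bit-exact; a NumPy analogy (illustrative only, not the delegate code):

    import numpy as np

    w = np.random.rand(2, 3).astype(np.float16)   # FP16 weights

    # Transpose the 16-bit payloads without ever converting to float32.
    bits_t = w.view(np.uint16).T.copy()

    # Reinterpreting the transposed bits as float16 gives exactly w.T.
    assert np.array_equal(bits_t.view(np.float16), w.T)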

void ConvolutionOpBuilder::FillCoreMLBias() {
  if (bias_ != nullptr) {
    layer_->mutable_convolution()->set_hasbias(true);
    std::copy(bias_->data.f, bias_->data.f + NumElements(bias_->dims),
              google::protobuf::RepeatedFieldBackInserter(layer_->mutable_convolution()
                                                    ->mutable_bias()
                                                    ->mutable_floatvalue()));
    if (bias_->type == kTfLiteFloat32) {
      std::copy(bias_->data.f, bias_->data.f + NumElements(bias_->dims),
                google::protobuf::RepeatedFieldBackInserter(layer_->mutable_convolution()
                                                      ->mutable_bias()
                                                      ->mutable_floatvalue()));
    } else if (bias_->type == kTfLiteFloat16) {
      // float16value has type of bytes (std::string)
      layer_->mutable_convolution()
          ->mutable_bias()
          ->mutable_float16value()
          ->assign(bias_->data.raw, bias_->bytes);
    }
  }
}

@@ -0,0 +1,39 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/lite/experimental/delegates/coreml/builders/dummy_op_builder.h"

#include "tensorflow/lite/experimental/delegates/coreml/builders/op_factory.h"

namespace tflite {
namespace delegates {
namespace coreml {

CoreML::Specification::NeuralNetworkLayer* DummyOpBuilder::Build() {
  return nullptr;
}

const char* DummyOpBuilder::DebugName() { return "Dummy OpBuilder"; }

TfLiteStatus DummyOpBuilder::PopulateSubgraph(TfLiteContext* context) {
  return kTfLiteOk;
}

OpBuilder* CreateDummyOpBuilder(GraphBuilder* graph_builder) {
  return new DummyOpBuilder(graph_builder);
}

}  // namespace coreml
}  // namespace delegates
}  // namespace tflite

@@ -0,0 +1,41 @@
/* Copyright 2020 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#ifndef TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_DUMMY_OP_BUILDER_H_
#define TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_DUMMY_OP_BUILDER_H_

#include "tensorflow/lite/builtin_ops.h"
#include "tensorflow/lite/c/builtin_op_data.h"
#include "tensorflow/lite/c/common.h"
#include "tensorflow/lite/experimental/delegates/coreml/builders/op_builder.h"

namespace tflite {
namespace delegates {
namespace coreml {

// Dummy OpBuilder for nodes that are claimed but not used, e.g. FP16 dequantize.
class DummyOpBuilder : public OpBuilder {
 public:
  explicit DummyOpBuilder(GraphBuilder* graph_builder)
      : OpBuilder(graph_builder) {}
  CoreML::Specification::NeuralNetworkLayer* Build() override;
  TfLiteStatus PopulateSubgraph(TfLiteContext* context) override;
  const char* DebugName() override;
};

}  // namespace coreml
}  // namespace delegates
}  // namespace tflite

#endif  // TENSORFLOW_LITE_EXPERIMENTAL_DELEGATES_COREML_BUILDERS_DUMMY_OP_BUILDER_H_

@@ -24,7 +24,6 @@ namespace tflite {
namespace delegates {
namespace coreml {
OpBuilder* GraphBuilder::AddBuilder(int builtin_code, const TfLiteNode* node) {
  // Follow the ordering of TfLiteBuiltinOperator enum.
  switch (builtin_code) {
    case kTfLiteBuiltinAdd:
      return AddBuilder(CreateAddOpBuilder, node);
@@ -36,6 +35,11 @@ OpBuilder* GraphBuilder::AddBuilder(int builtin_code, const TfLiteNode* node) {
      return AddBuilder(CreateConvolutionOpBuilder, node);
    case kTfLiteBuiltinDepthwiseConv2d:
      return AddBuilder(CreateDepthwiseConvolutionOpBuilder, node);
    // TODO(b/141490853): Add proper dequantize OpBuilder for int8/uint8 inputs.
    case kTfLiteBuiltinDequantize:
      // FP16 dequantize is claimed by the delegate to prevent it from running
      // on the CPU, but it does not need to be executed on the Core ML delegate
      // either.
      return AddBuilder(CreateDummyOpBuilder, node);
    case kTfLiteBuiltinFullyConnected:
      return AddBuilder(CreateFullyConnectedOpBuilder, node);
    case kTfLiteBuiltinLogistic:

@@ -44,6 +44,8 @@ OpBuilder* CreateTransposeConvolutionOpBuilder(GraphBuilder* graph_builder);

OpBuilder* CreateActivationLayerBuilder(GraphBuilder* graph_builder);
OpBuilder* CreateThresholdLayerBuilder(GraphBuilder* graph_builder);
// Dummy OpBuilder for nodes that are claimed but not used, e.g. FP16 dequantize.
OpBuilder* CreateDummyOpBuilder(GraphBuilder* graph_builder);

}  // namespace coreml
}  // namespace delegates

@@ -193,8 +193,7 @@ TfLiteRegistration GetCoreMlKernelRegistration() {
  kernel_registration.init = [](TfLiteContext* context, const char* buffer,
                                size_t length) -> void* {
    const auto* params = reinterpret_cast<const TfLiteDelegateParams*>(buffer);
    const auto* coreml_options =
        (reinterpret_cast<CoreMlDelegate*>(params->delegate))->params();
    const auto* coreml_options = (reinterpret_cast<CoreMlDelegate*>(params->delegate))->params();
    CoreMlDelegateKernel* coreml_kernel = new CoreMlDelegateKernel(coreml_options->coreml_version);
    if (coreml_kernel->Init(context, params) != kTfLiteOk) {
      delete coreml_kernel;
@@ -231,31 +230,19 @@ TfLiteStatus DelegatePrepare(TfLiteContext* context, TfLiteDelegate* delegate) {
    return IsNodeSupportedByDelegate(registration, node, context, params);
  };

  delegates::GraphPartitionHelper helper(context, node_supported_fn);
  TF_LITE_ENSURE_STATUS(helper.Partition(nullptr));
  delegates::FP16GraphPartitionHelper partition_helper(context, node_supported_fn);
  TF_LITE_ENSURE_STATUS(partition_helper.Partition(nullptr));

  const auto delegate_partitions = helper.GetFirstNLargestPartitions(
      params->max_delegated_partitions, params->min_nodes_per_partition);

  // To avoid creating a new TfLiteIntArray and freeing it later, we reserve one
  // element to represent TfLiteIntArray.size, which is the first element of
  // the TfLiteIntArray C struct.
  std::vector<int> supported_nodes(1);
  for (const auto partition : delegate_partitions) {
    auto nodes = TfLiteIntArrayView(partition->nodes_to_replace);
    supported_nodes.insert(supported_nodes.end(), nodes.begin(), nodes.end());
  }

  // Set first element to the number of nodes to replace.
  supported_nodes[0] = supported_nodes.size() - 1;
  std::vector<TfLiteDelegateParams*> partitions;
  std::vector<int> delegated_nodes = partition_helper.GetNodesOfFirstNLargestPartitions(
      params->max_delegated_partitions, params->min_nodes_per_partition, &partitions);
  TFLITE_LOG_PROD(tflite::TFLITE_LOG_INFO,
                  "CoreML delegate: %d nodes delegated out of %d nodes, "
                  "with %d partitions.\n",
                  supported_nodes[0], helper.num_total_nodes(), delegate_partitions.size());

                  delegated_nodes.size(), partition_helper.num_total_nodes(), partitions.size());
  return context->ReplaceNodeSubsetsWithDelegateKernels(
      context, GetCoreMlKernelRegistration(),
      reinterpret_cast<TfLiteIntArray*>(supported_nodes.data()), delegate);
      context, GetCoreMlKernelRegistration(), delegates::BuildTfLiteIntArray(delegated_nodes).get(),
      delegate);
}

TfLiteDelegate* CreateCoreMlDelegate(const TfLiteCoreMlDelegateOptions* options) {

@@ -830,13 +830,14 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
    in_call = call_context.in_call
    input_list = nest.flatten(inputs)

    # We will attempt to trace in a graph if & only if inputs are symbolic.
    # This is always the case when tracing a function. It can also be the case
    # when running eagerly if any input can be traced back to `keras.Input()`
    # (when building models using the functional API).
    build_graph = tf_utils.are_all_symbolic_tensors(input_list) or (
        any(map(tf_utils.is_symbolic_tensor, nest.flatten(
            [input_list, args, kwargs]))) and context.executing_eagerly())
    # We will attempt to build a TF graph if & only if all inputs are symbolic.
    # This is always the case in graph mode. It can also be the case in eager
    # mode when all inputs can be traced back to `keras.Input()` (when building
    # models using the functional API).
    # TODO(kaftan): make this not special case inputs. Instead
    # build a functional api model if *any* *arg or **kwarg is symbolic,
    # even if part of the data structure in that arg is not symbolic.
    build_graph = tf_utils.are_all_symbolic_tensors(input_list)

    # Accept NumPy and scalar inputs by converting to Tensors.
    if any(isinstance(x, (np.ndarray, float, int)) for x in input_list):
@@ -889,14 +890,11 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
          'training', training_value, args, kwargs)
      training_arg_passed_by_framework = True

    # Turn inputs into TF op layers if necessary.
    # This process is fragile and prone to bad interactions with inputs
    # when calling nested layers with tf.functions floating around,
    # and with nonsymbolic tensors.
    # So, we limit it to the
    # case where *all* inputs in the first arg are symbolic.
    if (tf_utils.are_all_symbolic_tensors(input_list)
        and base_layer_utils.needs_keras_history(inputs)):
    # Only create Keras history if at least one tensor originates from a
    # `keras.Input`. Otherwise this Layer may be being used outside the Keras
    # framework.
    # TODO(kaftan): make this not special case inputs
    if build_graph and base_layer_utils.needs_keras_history(inputs):
      base_layer_utils.create_keras_history(inputs)

    with call_context.enter(self, inputs, build_graph, training_value):
@@ -970,12 +968,8 @@ class Layer(module.Module, version_utils.LayerVersionSelector):
          raise ValueError('A layer\'s `call` method should return a '
                           'Tensor or a list of Tensors, not None '
                           '(layer: ' + self.name + ').')
        # We configure connectivity metadata if all inputs in the first
        # arg have keras history, or if we're actively building the
        # functional api outside of any outer keras model.
        if base_layer_utils.have_all_keras_metadata(inputs) or (
            context.executing_eagerly() and
            base_layer_utils.have_any_keras_metadata(inputs, args, kwargs)):
        # TODO(kaftan): This should be 'any' and check all args
        if base_layer_utils.have_all_keras_metadata(inputs):
          if training_arg_passed_by_framework:
            args, kwargs = self._set_call_arg_value(
                'training', None, args, kwargs, pop_kwarg_if_none=True)

@@ -165,10 +165,6 @@ def have_all_keras_metadata(tensors):
  return all(hasattr(x, '_keras_history') for x in nest.flatten(tensors))


def have_any_keras_metadata(*tensors):
  return any(hasattr(x, '_keras_history') for x in nest.flatten(tensors))


def generate_placeholders_from_shape(shape):
  return array_ops.placeholder(shape=shape, dtype=backend.floatx())

@@ -218,10 +214,7 @@ def _create_keras_history_helper(tensors, processed_ops, created_layers):
  for tensor in tensor_list:
    if getattr(tensor, '_keras_history', None) is not None:
      continue
    try:
      op = tensor.op  # The Op that created this Tensor.
    except AttributeError:
      continue
    op = tensor.op  # The Op that created this Tensor.
    if op not in processed_ops:
      if op.type.startswith('Sparse'):
        lambda_example = """
@@ -399,10 +392,7 @@ def mark_checked(tensors):
  """

  def _mark_checked(tensor):
    try:
      tensor._keras_history_checked = True  # pylint: disable=protected-access
    except AttributeError:
      pass
    tensor._keras_history_checked = True  # pylint: disable=protected-access

  nest.map_structure(_mark_checked, tensors)

@@ -32,7 +32,6 @@ from tensorflow.python.keras import backend
from tensorflow.python.keras.engine import base_layer
from tensorflow.python.keras.engine import base_layer_utils
from tensorflow.python.keras.engine import input_layer as input_layer_module
from tensorflow.python.keras.engine import node as node_module
from tensorflow.python.keras.engine import training as training_lib
from tensorflow.python.keras.engine import training_utils
from tensorflow.python.keras.saving.saved_model import network_serialization
@@ -1112,28 +1111,19 @@ def reconstruct_from_config(config, custom_objects=None, created_layers=None):
        kwargs = {}
      elif len(input_data) == 4:
        kwargs = input_data[3]
        try:
          kwargs = _deserialize_keras_tensors(kwargs, created_layers)
        except IndexError:
          # Happens if keras tensors in kwargs are still unprocessed
          add_unprocessed_node(layer, node_data)
          return
        kwargs = _deserialize_keras_tensors(kwargs, created_layers)
      else:
        raise ValueError('Improperly formatted model config.')

      if inbound_layer_name != node_module._CONSTANT_VALUE:
        inbound_layer = created_layers[inbound_layer_name]
        inbound_node_index = get_node_index(inbound_layer, inbound_node_index)
      inbound_layer = created_layers[inbound_layer_name]
      inbound_node_index = get_node_index(inbound_layer, inbound_node_index)

        if inbound_node_index is None:
          add_unprocessed_node(layer, node_data)
          return
        inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
        input_tensors.append(
            nest.flatten(inbound_node.outputs)[inbound_tensor_index])
      else:
        # We received a constant w/ no Keras history attached
        input_tensors.append(inbound_tensor_index)
      if inbound_node_index is None:
        add_unprocessed_node(layer, node_data)
        return
      inbound_node = inbound_layer._inbound_nodes[inbound_node_index]
      input_tensors.append(
          nest.flatten(inbound_node.outputs)[inbound_tensor_index])
    input_tensors = nest.pack_sequence_as(node_data, input_tensors)
    # Call layer on its inputs, thus creating the node
    # and building the layer if needed.

@@ -964,43 +964,6 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_call_kwarg_derived_from_keras_layer_and_first_arg_is_constant(self):

    class MaybeAdd(layers.Layer):

      def call(self, x1, x2=None):
        if x2 is not None:
          return x1 + x2
        return x1

    input2 = input_layer_lib.Input(10)
    outputs = MaybeAdd()(3., x2=input2)
    model = training_lib.Model([input2], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=7 * np.ones((10, 10)),
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'MaybeAdd': MaybeAdd})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=7 * np.ones((10, 10)),
        y=10 * np.ones((10, 10)),
        batch_size=2)
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_composite_call_kwarg_derived_from_keras_layer(self):

@@ -1042,56 +1005,6 @@ class NetworkConstructionTest(keras_parameterized.TestCase):
    # Check that second input was correctly added to first.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations(mode='eager'))
  def test_call_some_not_all_nested_in_first_arg_derived_from_keras_layer(self):
    # This functionality is unsupported in v1 graphs

    class AddAll(layers.Layer):

      def call(self, x1_x2, x3):
        x1, x2 = x1_x2
        out = x1 + x2
        if x3 is not None:
          for t in x3.values():
            out += t
        return out

    input1 = input_layer_lib.Input(10)
    input2 = input_layer_lib.Input(10)
    input3 = input_layer_lib.Input(10)

    outputs = AddAll()(
        [input1, 4 * array_ops.ones((1, 10))],
        x3={
            'a': input2,
            'b': input3,
            'c': 5 * array_ops.ones((1, 10))
        })
    model = training_lib.Model([input1, input2, input3], outputs)
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
        y=15 * np.ones((10, 10)),
        batch_size=2)
    # Check that all inputs were correctly added.
    self.assertEqual(history.history['loss'][0], 0.0)

    model = training_lib.Model.from_config(
        model.get_config(), custom_objects={'AddAll': AddAll})
    model.compile(
        'sgd',
        'mse',
        run_eagerly=testing_utils.should_run_eagerly())
    history = model.fit(
        x=[np.ones((10, 10)), 2 * np.ones((10, 10)), 3 * np.ones((10, 10))],
        y=15 * np.ones((10, 10)),
        batch_size=2)
    # Check that all inputs were correctly added.
    self.assertEqual(history.history['loss'][0], 0.0)

  @combinations.generate(combinations.keras_mode_combinations())
  def test_call_nested_arg_derived_from_keras_layer(self):

@@ -32,8 +32,6 @@ from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import nest
from tensorflow.python.util import serialization

_CONSTANT_VALUE = '_CONSTANT_VALUE'


class Node(object):
  """A `Node` describes the connectivity between two layers.
@@ -183,14 +181,11 @@ class Node(object):
    # `kwargs` is added to each Tensor in the first arg. This should be
    # changed in a future version of the serialization format.
    def serialize_first_arg_tensor(t):
      if is_keras_tensor(t):
        kh = t._keras_history
        node_index = kh.node_index
        node_key = make_node_key(kh.layer.name, node_index)
        new_node_index = node_conversion_map.get(node_key, 0)
        data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
      else:
        data = [_CONSTANT_VALUE, -1, _serialize_keras_tensor(t), kwargs]
      kh = t._keras_history
      node_index = kh.node_index
      node_key = make_node_key(kh.layer.name, node_index)
      new_node_index = node_conversion_map.get(node_key, 0)
      data = [kh.layer.name, new_node_index, kh.tensor_index, kwargs]
      return tf_utils.ListWrapper(data)

    data = nest.map_structure(serialize_first_arg_tensor, inputs)
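For reference, the per-tensor node config built by serialize_first_arg_tensor has this shape (values below are made up for illustration):

    data = ['dense_1',            # name of the layer that produced the tensor
            0,                    # node index, remapped via node_conversion_map
            0,                    # tensor index within that node's outputs
            {'training': None}]   # call kwargs, duplicated onto every
                                  # first-arg tensor (see the comment above)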

@@ -655,8 +655,8 @@ def tf_repositories(path_prefix = "", tf_repo_name = ""):
    )

    # Check out LLVM and MLIR from llvm-project.
    LLVM_COMMIT = "cf86a234ba86acf0bb875e21d27833be36e08be4"
    LLVM_SHA256 = "5375bdcdabd4886ab86eddfddef6e21dbc3cac9df67af7d3c44fadb527f74e25"
    LLVM_COMMIT = "b726d071b4aa46004228fc38ee5bfd167f999bfe"
    LLVM_SHA256 = "d7e67036dc89906cb2f80df7b0b7de6344d86eddf6e98bb4d01a578242889a73"
    LLVM_URLS = [
        "https://storage.googleapis.com/mirror.tensorflow.org/github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),
        "https://github.com/llvm/llvm-project/archive/{commit}.tar.gz".format(commit = LLVM_COMMIT),

third_party/mlir/BUILD (vendored, 37 lines changed):
@@ -1176,28 +1176,6 @@ cc_library(
    ],
)

cc_library(
    name = "GPURuntimeTransforms",
    srcs = [
        "lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp",
        "lib/Conversion/PassDetail.h",
    ],
    hdrs = [
        "include/mlir/Conversion/GPUCommon/GPUCommonPass.h",
    ],
    includes = ["include"],
    deps = [
        ":ConversionPassIncGen",
        ":GPUDialect",
        ":IR",
        ":LLVMDialect",
        ":Pass",
        ":Support",
        "@llvm-project//llvm:core",
        "@llvm-project//llvm:support",
    ],
)

gentbl(
    name = "GPUToNVVMGen",
    strip_include_prefix = "lib/Conversion/GPUToNVVM",
@@ -1307,12 +1285,13 @@ cc_library(
)

cc_library(
    name = "GPUToCUDATransforms",
    name = "GPUToGPURuntimeTransforms",
    srcs = [
        "lib/Conversion/GPUToCUDA/ConvertKernelFuncToCubin.cpp",
        "lib/Conversion/GPUCommon/ConvertKernelFuncToBlob.cpp",
        "lib/Conversion/GPUCommon/ConvertLaunchFuncToRuntimeCalls.cpp",
        "lib/Conversion/PassDetail.h",
    ],
    hdrs = ["include/mlir/Conversion/GPUToCUDA/GPUToCUDAPass.h"],
    hdrs = ["include/mlir/Conversion/GPUCommon/GPUCommonPass.h"],
    includes = ["include"],
    deps = [
        ":ConversionPassIncGen",
@@ -2490,7 +2469,7 @@ cc_library(
    includes = ["include"],
    deps = [
        ":Analysis",
        ":GPURuntimeTransforms",
        ":GPUToGPURuntimeTransforms",
        ":GPUToNVVMTransforms",
        ":GPUToROCDLTransforms",
        ":GPUToSPIRVTransforms",
@@ -2570,8 +2549,7 @@ cc_library(
        ":ConversionPassIncGen",
        ":GPUDialect",
        ":GPUPassIncGen",
        ":GPURuntimeTransforms",
        ":GPUToCUDATransforms",
        ":GPUToGPURuntimeTransforms",
        ":GPUToNVVMTransforms",
        ":GPUToROCDLTransforms",
        ":GPUToSPIRVTransforms",
@@ -2776,7 +2754,7 @@ cc_binary(
        ":AllPassesAndDialectsNoRegistration",
        ":ExecutionEngineUtils",
        ":GPUDialect",
        ":GPURuntimeTransforms",
        ":GPUToGPURuntimeTransforms",
        ":GPUToNVVMTransforms",
        ":GPUToROCDLTransforms",
        ":GPUTransforms",
@@ -2786,6 +2764,7 @@ cc_binary(
        ":MlirJitRunner",
        ":NVVMDialect",
        ":Pass",
        ":TargetNVVMIR",
        ":Transforms",
        "//devtools/build/runtime:get_runfiles_dir",
        "//third_party/gpus/cuda:cuda_headers",

third_party/mlir/test.BUILD (vendored, 4 lines changed):
@@ -158,7 +158,7 @@ cc_library(
        "@llvm-project//mlir:Analysis",
        "@llvm-project//mlir:EDSC",
        "@llvm-project//mlir:GPUDialect",
        "@llvm-project//mlir:GPUToCUDATransforms",
        "@llvm-project//mlir:GPUToGPURuntimeTransforms",
        "@llvm-project//mlir:GPUTransforms",
        "@llvm-project//mlir:IR",
        "@llvm-project//mlir:LinalgOps",
@@ -167,6 +167,8 @@ cc_library(
        "@llvm-project//mlir:SCFDialect",
        "@llvm-project//mlir:StandardOps",
        "@llvm-project//mlir:Support",
        "@llvm-project//mlir:TargetNVVMIR",
        "@llvm-project//mlir:TargetROCDLIR",
        "@llvm-project//mlir:TransformUtils",
        "@llvm-project//mlir:Transforms",
        "@llvm-project//mlir:VectorOps",