Merge changes from github.

PiperOrigin-RevId: 177526301
This commit is contained in:
Sourabh Bajaj 2017-11-30 16:37:11 -08:00 committed by TensorFlower Gardener
parent 0438ac79bd
commit b2db981a67
171 changed files with 4447 additions and 579 deletions

.gitignore
View File

@ -22,3 +22,8 @@ Pods
Podfile.lock
*.pbxproj
*.xcworkspacedata
/tensorflow/contrib/lite/downloads/**
/tensorflow/contrib/lite/gen/**
/tensorflow/contrib/lite/examples/ios/simple/data/*.txt
/tensorflow/contrib/lite/examples/ios/simple/data/*.tflite
xcuserdata/**

View File

@ -346,9 +346,9 @@ void XlaOpKernelContext::SetConstantOutput(int index, const Tensor& constant) {
}
void XlaOpKernelContext::SetInvalidOutput(int index) {
const TensorShape shape;
Tensor* output = nullptr;
OP_REQUIRES_OK(context_, context_->allocate_output(index, shape, &output));
OP_REQUIRES_OK(context_,
context_->allocate_output(index, TensorShape({}), &output));
XlaExpression* expression = CastExpressionFromUninitializedTensor(output);
xla::ComputationDataHandle handle;
handle.set_handle(0);

View File

@ -175,6 +175,7 @@ cc_library(
":types",
":xla_data_proto",
"//tensorflow/core:lib",
"//tensorflow/core:ptr_util",
],
)

View File

@ -16,7 +16,8 @@ limitations under the License.
#ifndef TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
#define TENSORFLOW_COMPILER_XLA_PTR_UTIL_H_
// Utility functions for pointers.
// As this was moved to tensorflow/core/util, provide indirections here to
// maintain current functionality of the library.
#include <stddef.h>
@ -24,55 +25,27 @@ limitations under the License.
#include <type_traits>
#include <utility>
#include "tensorflow/core/util/ptr_util.h"
namespace xla {
namespace internal {
// Trait to select overloads and return types for MakeUnique.
template <typename T>
struct MakeUniqueResult {
using scalar = std::unique_ptr<T>;
};
template <typename T>
struct MakeUniqueResult<T[]> {
using array = std::unique_ptr<T[]>;
};
template <typename T, size_t N>
struct MakeUniqueResult<T[N]> {
using invalid = void;
};
} // namespace internal
// Transfers ownership of a raw pointer to a std::unique_ptr of deduced type.
// Example:
// X* NewX(int, int);
// auto x = WrapUnique(NewX(1, 2)); // 'x' is std::unique_ptr<X>.
//
// WrapUnique is useful for capturing the output of a raw pointer factory.
// However, prefer 'MakeUnique<T>(args...) over 'WrapUnique(new T(args...))'.
// auto x = WrapUnique(new X(1, 2)); // works, but nonideal.
// auto x = MakeUnique<X>(1, 2); // safer, standard, avoids raw 'new'.
//
// Note: Cannot wrap pointers to array of unknown bound (i.e. U(*)[]).
template <typename T>
std::unique_ptr<T> WrapUnique(T* ptr) {
static_assert(!std::is_array<T>::value || std::extent<T>::value != 0,
"types T[0] or T[] are unsupported");
return std::unique_ptr<T>(ptr);
return tensorflow::WrapUnique<T>(ptr);
}
template <typename T, typename... Args>
typename internal::MakeUniqueResult<T>::scalar MakeUnique(Args&&... args) {
return std::unique_ptr<T>(new T(std::forward<Args>(args)...));
typename tensorflow::helper::MakeUniqueResult<T>::scalar MakeUnique(
Args&&... args) {
return tensorflow::MakeUnique<T, Args...>(std::forward<Args>(args)...);
}
// Overload for array of unknown bound.
// The allocation of arrays needs to use the array form of new,
// and cannot take element constructor arguments.
template <typename T>
typename internal::MakeUniqueResult<T>::array MakeUnique(size_t n) {
return std::unique_ptr<T>(new typename std::remove_extent<T>::type[n]());
typename tensorflow::helper::MakeUniqueResult<T>::array MakeUnique(size_t n) {
return tensorflow::MakeUnique<T>(n);
}
} // namespace xla
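For reference, a minimal sketch of how these forwarding shims are exercised after the change; the caller below is hypothetical, but each call now resolves to the tensorflow::MakeUnique / tensorflow::WrapUnique implementations in tensorflow/core/util/ptr_util.h:

#include <memory>
#include <utility>

#include "tensorflow/compiler/xla/ptr_util.h"

void PtrUtilExample() {
  // Scalar overload: forwards constructor arguments to T's constructor.
  auto pair = xla::MakeUnique<std::pair<int, int>>(1, 2);

  // Array-of-unknown-bound overload: value-initializes n elements and
  // cannot take constructor arguments.
  auto buffer = xla::MakeUnique<float[]>(16);

  // WrapUnique: takes ownership of a raw pointer, e.g. from a factory.
  auto owned = xla::WrapUnique(new int(42));
}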

View File

@ -85,7 +85,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunBufferAssignment(HloModule* module,
int64 alignment = 1) {
return BufferAssigner::Run(
module, MakeUnique<DependencyHloOrdering>(module),
module, xla::MakeUnique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; })
.ConsumeValueOrDie();
@ -94,7 +94,7 @@ class BufferAssignmentTest : public HloTestBase {
std::unique_ptr<BufferAssignment> RunColoredBufferAssignment(
HloModule* module, BufferLiveness::Colorer colorer, int64 alignment = 1) {
return BufferAssigner::Run(
module, MakeUnique<DependencyHloOrdering>(module),
module, xla::MakeUnique<DependencyHloOrdering>(module),
backend().compiler()->BufferSizeBytesFunction(),
[alignment](LogicalBuffer::Color) { return alignment; }, false,
std::move(colorer))
@ -1451,7 +1451,7 @@ class WhileBufferAssignmentTest : public HloTestBase {
auto sequence =
CreateMemoryMinimizingSequence(*module, ByteSizeOf).ConsumeValueOrDie();
return BufferAssigner::Run(
module, MakeUnique<SequentialHloOrdering>(module, sequence),
module, xla::MakeUnique<SequentialHloOrdering>(module, sequence),
ByteSizeOf,
[alignment](LogicalBuffer::Color) { return alignment; })
.ConsumeValueOrDie();
@ -1472,7 +1472,7 @@ static void RunCopyInsertion(HloModule* module) {
}
TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
auto module = MakeUnique<HloModule>(TestName());
auto module = xla::MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
@ -1529,7 +1529,7 @@ TEST_F(WhileBufferAssignmentTest, TwoForwardWhileLoops) {
}
TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
auto module = MakeUnique<HloModule>(TestName());
auto module = xla::MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(
@ -1574,7 +1574,7 @@ TEST_F(WhileBufferAssignmentTest, OneForwardBackwardWhileLoopSet) {
}
TEST_F(BufferAssignmentTest, TwoCalls) {
auto module = MakeUnique<HloModule>(TestName());
auto module = xla::MakeUnique<HloModule>(TestName());
Shape r0f32 = ShapeUtil::MakeShape(xla::F32, {});
HloComputation* sub_computation;
{
@ -1639,7 +1639,7 @@ static bool IsPostOrderTraversal(
}
TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto module = MakeUnique<HloModule>(TestName());
auto module = xla::MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder(TestName());
auto zero = builder.AddInstruction(
@ -1710,15 +1710,15 @@ TEST_F(WhileBufferAssignmentTest, WhileLoopsInterferingResultRange) {
auto assignment =
BufferAssigner::Run(
module.get(),
MakeUnique<SequentialHloOrdering>(module.get(), sequence), ByteSizeOf,
[](LogicalBuffer::Color) { return 1; })
xla::MakeUnique<SequentialHloOrdering>(module.get(), sequence),
ByteSizeOf, [](LogicalBuffer::Color) { return 1; })
.ConsumeValueOrDie();
EXPECT_TRUE(BuffersDistinct({while0}, {while1}, *assignment));
}
TEST_F(WhileBufferAssignmentTest, WhilesDontShareEntryParamIfLiveOut) {
auto module = MakeUnique<HloModule>(TestName());
auto module = xla::MakeUnique<HloModule>(TestName());
auto builder = HloComputation::Builder("entry");
auto input0 = builder.AddInstruction(

View File

@ -120,7 +120,7 @@ TEST_F(BufferLivenessTest, ElementwiseChain) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, negate));
@ -167,10 +167,10 @@ TEST_F(BufferLivenessTest, MultipleEntryParameters_Sequential) {
SequentialHloOrdering::HloModuleSequence sequence;
sequence.insert({entry, {param0, negate, param1, exp, add}});
auto liveness = BufferLiveness::Run(
module.get(),
MakeUnique<SequentialHloOrdering>(module.get(), sequence))
.ConsumeValueOrDie();
auto liveness =
BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
module.get(), sequence))
.ConsumeValueOrDie();
// Entry parameters interfere as if they are defined simultaneously at
// the very beginning.
@ -216,7 +216,7 @@ TEST_F(BufferLivenessTest, NonElementwiseOperand) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_FALSE(InstructionsMayInterfere(*liveness, param, exp));
@ -250,7 +250,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffers) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
EXPECT_TRUE(InstructionsMayInterfere(*liveness, param, negate));
@ -294,7 +294,7 @@ TEST_F(BufferLivenessTest, OverlappedBuffersSequentialOrder) {
std::vector<const HloInstruction*> order = {param, negate, exp, add};
module_sequence.emplace(computation, order);
auto liveness =
BufferLiveness::Run(module.get(), MakeUnique<SequentialHloOrdering>(
BufferLiveness::Run(module.get(), xla::MakeUnique<SequentialHloOrdering>(
module.get(), module_sequence))
.ConsumeValueOrDie();
@ -334,7 +334,7 @@ TEST_F(BufferLivenessTest, TupleLiveOut) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// All buffers should be live out except the param
@ -370,7 +370,7 @@ TEST_F(BufferLivenessTest, EmbeddedComputation) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Buffers in different computations should always interfere.
@ -409,7 +409,7 @@ TEST_F(BufferLivenessTest, TupleConstantLiveOut) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Only the element buffers of the tuple constant which are pointed to by
@ -474,7 +474,7 @@ TEST_F(BufferLivenessTest, IndependentTupleElements) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@ -536,7 +536,7 @@ TEST_F(BufferLivenessTest, DependentTupleElements) {
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// We compare tuple element pairs that are input/output to the computation:
@ -624,8 +624,8 @@ class FusedDynamicUpdateSliceLivenessTest : public BufferLivenessTest {
// Run BufferLiveness on 'module'.
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
BufferLiveness::Run(
module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.
@ -736,8 +736,8 @@ class DynamicUpdateSliceLivenessTest : public BufferLivenessTest {
module->AddEmbeddedComputation(builder.Build());
// Run BufferLiveness on 'module'.
auto liveness =
BufferLiveness::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()))
BufferLiveness::Run(
module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()))
.ConsumeValueOrDie();
// Return whether or not buffers interference is detected between
// 'tuple_param0' and 'tuple_root' at shape index '{1}'.

View File

@ -469,11 +469,11 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
&pre_optimization_ir_hook, &post_optimization_ir_hook));
// Compile must be thread-safe so create a new LLVM context for the module.
auto llvm_context = MakeUnique<llvm::LLVMContext>();
auto llvm_context = xla::MakeUnique<llvm::LLVMContext>();
auto llvm_module =
MakeUnique<llvm::Module>("__compute_module", *llvm_context);
xla::MakeUnique<llvm::Module>("__compute_module", *llvm_context);
auto jit = MakeUnique<SimpleOrcJIT>(
auto jit = xla::MakeUnique<SimpleOrcJIT>(
CompilerTargetOptions(module->config()),
CodeGenOptLevel(module->config()),
options::OptimizeForSizeRequested(module->config()),
@ -528,9 +528,9 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// uses data dependencies for determining order.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(module.get(),
MakeUnique<DependencyHloOrdering>(module.get()),
BufferSizeBytesFunction(), memory_alignment));
BufferAssigner::Run(
module.get(), xla::MakeUnique<DependencyHloOrdering>(module.get()),
BufferSizeBytesFunction(), memory_alignment));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@ -557,7 +557,7 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
const void* data = instruction->literal().InternalData();
int64 size = CpuExecutable::ShapeSizeBytes(instruction->shape());
auto iter = aligned_constants.emplace(
instruction, MakeUnique<unsigned char[]>(size));
instruction, xla::MakeUnique<unsigned char[]>(size));
CHECK_EQ(iter.second, true);
unsigned char* aligned_data = iter.first->second.get();
memcpy(aligned_data, data, size);
@ -642,10 +642,10 @@ StatusOr<std::unique_ptr<Executable>> CpuCompiler::RunBackend(
// temporary buffers are required to run the computation.
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(
module.get(),
MakeUnique<SequentialHloOrdering>(module.get(), module_sequence),
BufferSizeBytesFunction(), memory_alignment));
BufferAssigner::Run(module.get(),
xla::MakeUnique<SequentialHloOrdering>(
module.get(), module_sequence),
BufferSizeBytesFunction(), memory_alignment));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.
XLA_VLOG_LINES(2, assignment->ToString());
@ -824,7 +824,8 @@ CpuCompiler::CompileAheadOfTime(std::vector<std::unique_ptr<HloModule>> modules,
TF_ASSIGN_OR_RETURN(
std::unique_ptr<BufferAssignment> assignment,
BufferAssigner::Run(
module, MakeUnique<SequentialHloOrdering>(module, module_sequence),
module,
xla::MakeUnique<SequentialHloOrdering>(module, module_sequence),
BufferSizeBytesFunction(), memory_alignment));
// BufferAssignment::ToString() includes a header, so no need for us to
// print one ourselves.

View File

@ -213,71 +213,75 @@ bool RegisterKnownJITSymbols() {
#undef REGISTER_CPU_RUNTIME_SYMBOL
#define REGISTER_LIBM_SYMBOL(name) \
do { \
/* Register both the F32 and F64 variants of the libm symbol. */ \
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
registry->Register(#name, reinterpret_cast<void*>(name)); \
// Register both the f32 (float) and f64 (double) versions of a libm symbol.
// Unfortunately, on some systems (e.g. Mac) the double versions are
// overloaded, so we need an explicit cast. This requires passing the
// function signature for that case.
#define REGISTER_LIBM_SYMBOL(name, double_sig) \
do { \
registry->Register(#name "f", reinterpret_cast<void*>(name##f)); \
registry->Register( \
#name, reinterpret_cast<void*>(static_cast<double_sig>(name))); \
} while (false)
REGISTER_LIBM_SYMBOL(acos);
REGISTER_LIBM_SYMBOL(acosh);
REGISTER_LIBM_SYMBOL(asin);
REGISTER_LIBM_SYMBOL(asinh);
REGISTER_LIBM_SYMBOL(atan);
REGISTER_LIBM_SYMBOL(atan2);
REGISTER_LIBM_SYMBOL(atanh);
REGISTER_LIBM_SYMBOL(cbrt);
REGISTER_LIBM_SYMBOL(ceil);
REGISTER_LIBM_SYMBOL(copysign);
REGISTER_LIBM_SYMBOL(cos);
REGISTER_LIBM_SYMBOL(cosh);
REGISTER_LIBM_SYMBOL(erf);
REGISTER_LIBM_SYMBOL(erfc);
REGISTER_LIBM_SYMBOL(exp);
REGISTER_LIBM_SYMBOL(exp2);
REGISTER_LIBM_SYMBOL(expm1);
REGISTER_LIBM_SYMBOL(fabs);
REGISTER_LIBM_SYMBOL(fdim);
REGISTER_LIBM_SYMBOL(floor);
REGISTER_LIBM_SYMBOL(fma);
REGISTER_LIBM_SYMBOL(fmax);
REGISTER_LIBM_SYMBOL(fmin);
REGISTER_LIBM_SYMBOL(fmod);
REGISTER_LIBM_SYMBOL(frexp);
REGISTER_LIBM_SYMBOL(hypot);
REGISTER_LIBM_SYMBOL(ilogb);
REGISTER_LIBM_SYMBOL(ldexp);
REGISTER_LIBM_SYMBOL(lgamma);
REGISTER_LIBM_SYMBOL(llrint);
REGISTER_LIBM_SYMBOL(llround);
REGISTER_LIBM_SYMBOL(log);
REGISTER_LIBM_SYMBOL(log10);
REGISTER_LIBM_SYMBOL(log1p);
REGISTER_LIBM_SYMBOL(log2);
REGISTER_LIBM_SYMBOL(logb);
REGISTER_LIBM_SYMBOL(lrint);
REGISTER_LIBM_SYMBOL(lround);
REGISTER_LIBM_SYMBOL(modf);
REGISTER_LIBM_SYMBOL(nan);
REGISTER_LIBM_SYMBOL(nearbyint);
REGISTER_LIBM_SYMBOL(nextafter);
REGISTER_LIBM_SYMBOL(nexttoward);
REGISTER_LIBM_SYMBOL(pow);
REGISTER_LIBM_SYMBOL(remainder);
REGISTER_LIBM_SYMBOL(remquo);
REGISTER_LIBM_SYMBOL(rint);
REGISTER_LIBM_SYMBOL(round);
REGISTER_LIBM_SYMBOL(scalbln);
REGISTER_LIBM_SYMBOL(scalbn);
REGISTER_LIBM_SYMBOL(sin);
REGISTER_LIBM_SYMBOL(sincos);
REGISTER_LIBM_SYMBOL(sinh);
REGISTER_LIBM_SYMBOL(sqrt);
REGISTER_LIBM_SYMBOL(tan);
REGISTER_LIBM_SYMBOL(tanh);
REGISTER_LIBM_SYMBOL(tgamma);
REGISTER_LIBM_SYMBOL(trunc);
REGISTER_LIBM_SYMBOL(acos, double (*)(double));
REGISTER_LIBM_SYMBOL(acosh, double (*)(double));
REGISTER_LIBM_SYMBOL(asin, double (*)(double));
REGISTER_LIBM_SYMBOL(asinh, double (*)(double));
REGISTER_LIBM_SYMBOL(atan, double (*)(double));
REGISTER_LIBM_SYMBOL(atan2, double (*)(double, double));
REGISTER_LIBM_SYMBOL(atanh, double (*)(double));
REGISTER_LIBM_SYMBOL(cbrt, double (*)(double));
REGISTER_LIBM_SYMBOL(ceil, double (*)(double));
REGISTER_LIBM_SYMBOL(copysign, double (*)(double, double));
REGISTER_LIBM_SYMBOL(cos, double (*)(double));
REGISTER_LIBM_SYMBOL(cosh, double (*)(double));
REGISTER_LIBM_SYMBOL(erf, double (*)(double));
REGISTER_LIBM_SYMBOL(erfc, double (*)(double));
REGISTER_LIBM_SYMBOL(exp, double (*)(double));
REGISTER_LIBM_SYMBOL(exp2, double (*)(double));
REGISTER_LIBM_SYMBOL(expm1, double (*)(double));
REGISTER_LIBM_SYMBOL(fabs, double (*)(double));
REGISTER_LIBM_SYMBOL(fdim, double (*)(double, double));
REGISTER_LIBM_SYMBOL(floor, double (*)(double));
REGISTER_LIBM_SYMBOL(fma, double (*)(double, double, double));
REGISTER_LIBM_SYMBOL(fmax, double (*)(double, double));
REGISTER_LIBM_SYMBOL(fmin, double (*)(double, double));
REGISTER_LIBM_SYMBOL(fmod, double (*)(double, double));
REGISTER_LIBM_SYMBOL(frexp, double (*)(double, int*));
REGISTER_LIBM_SYMBOL(hypot, double (*)(double, double));
REGISTER_LIBM_SYMBOL(ilogb, int (*)(double));
REGISTER_LIBM_SYMBOL(ldexp, double (*)(double, int));
REGISTER_LIBM_SYMBOL(lgamma, double (*)(double));
REGISTER_LIBM_SYMBOL(llrint, long long (*)(double));
REGISTER_LIBM_SYMBOL(llround, long long (*)(double));
REGISTER_LIBM_SYMBOL(log, double (*)(double));
REGISTER_LIBM_SYMBOL(log10, double (*)(double));
REGISTER_LIBM_SYMBOL(log1p, double (*)(double));
REGISTER_LIBM_SYMBOL(log2, double (*)(double));
REGISTER_LIBM_SYMBOL(logb, double (*)(double));
REGISTER_LIBM_SYMBOL(lrint, long (*)(double));
REGISTER_LIBM_SYMBOL(lround, long (*)(double));
REGISTER_LIBM_SYMBOL(modf, double (*)(double, double*));
REGISTER_LIBM_SYMBOL(nan, double (*)(const char*));
REGISTER_LIBM_SYMBOL(nearbyint, double (*)(double));
REGISTER_LIBM_SYMBOL(nextafter, double (*)(double, double));
REGISTER_LIBM_SYMBOL(nexttoward, double (*)(double, long double));
REGISTER_LIBM_SYMBOL(pow, double (*)(double, double));
REGISTER_LIBM_SYMBOL(remainder, double (*)(double, double));
REGISTER_LIBM_SYMBOL(remquo, double (*)(double, double, int*));
REGISTER_LIBM_SYMBOL(rint, double (*)(double));
REGISTER_LIBM_SYMBOL(round, double (*)(double));
REGISTER_LIBM_SYMBOL(scalbln, double (*)(double, long));
REGISTER_LIBM_SYMBOL(scalbn, double (*)(double, int));
REGISTER_LIBM_SYMBOL(sin, double (*)(double));
REGISTER_LIBM_SYMBOL(sincos, void (*)(double, double*, double*));
REGISTER_LIBM_SYMBOL(sinh, double (*)(double));
REGISTER_LIBM_SYMBOL(sqrt, double (*)(double));
REGISTER_LIBM_SYMBOL(tan, double (*)(double));
REGISTER_LIBM_SYMBOL(tanh, double (*)(double));
REGISTER_LIBM_SYMBOL(tgamma, double (*)(double));
REGISTER_LIBM_SYMBOL(trunc, double (*)(double));
#undef REGISTER_LIBM_SYMBOL
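The explicit cast matters because taking the address of an overloaded function is ambiguous until a target type selects one overload. A standalone sketch of the mechanism, using a hypothetical overload set rather than libm itself:

// f is overloaded, so a bare `reinterpret_cast<void*>(f)` would be
// ill-formed: the compiler cannot tell which overload is meant.
double f(double x) { return x; }
float f(float x) { return x; }

void* SelectOverload() {
  // Naming the desired signature via static_cast picks one overload; the
  // result can then be converted to void* as the macro above does.
  return reinterpret_cast<void*>(static_cast<double (*)(double)>(f));
}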

View File

@ -450,7 +450,7 @@ message ConvolutionDimensionNumbers {
message ConvolveRequest {
ComputationDataHandle lhs = 2;
ComputationDataHandle rhs = 3; // This is the filter/kernel.
Window window = 4; // Describes the filter/kenel.
Window window = 4; // Describes the filter/kernel.
ConvolutionDimensionNumbers dimension_numbers = 5;
}

View File

@ -56,7 +56,7 @@ class BatchFeatures {
*num_sparse_int_features = sparse_int_feature_columns_.size();
if (*num_dense_float_features == 0 && *num_sparse_float_features == 0 &&
*num_sparse_int_features == 0) {
return errors::FailedPrecondition("Not intialized yet.");
return errors::FailedPrecondition("Not initialized yet.");
}
return Status::OK();
}

View File

@ -63,7 +63,7 @@ class SparseFloatFeatureColumn {
public:
void Reserve(const int32 size) {
if (!single_dimensional_) {
mutlidimensional_values.Reserve(size);
multidimensional_values.Reserve(size);
}
}
@ -76,7 +76,7 @@ class SparseFloatFeatureColumn {
DCHECK_EQ(0, feature_idx);
single_value_ = value;
} else {
mutlidimensional_values.Add(feature_idx, value);
multidimensional_values.Add(feature_idx, value);
}
initialized_ = true;
}
@ -84,7 +84,7 @@ class SparseFloatFeatureColumn {
void Clear() {
single_dimensional_ = false;
initialized_ = false;
mutlidimensional_values.Clear();
multidimensional_values.Clear();
}
OptionalValue<T> operator[](int feature_idx) const {
@ -94,7 +94,7 @@ class SparseFloatFeatureColumn {
if (single_dimensional_) {
return OptionalValue<T>(single_value_);
} else {
return mutlidimensional_values[feature_idx];
return multidimensional_values[feature_idx];
}
}
@ -102,7 +102,7 @@ class SparseFloatFeatureColumn {
bool single_dimensional_;
bool initialized_;
T single_value_;
SparseMultidimensionalValues<T> mutlidimensional_values;
SparseMultidimensionalValues<T> multidimensional_values;
};
// Holds data for one example and enables lookup by feature column.

View File

@ -96,6 +96,10 @@ class IndicesRowIterator
return (row_idx_ != other.row_idx_);
}
bool operator<(const IndicesRowIterator& other) const {
return (row_idx_ < other.row_idx_);
}
bool operator==(const IndicesRowIterator& other) const {
QCHECK_EQ(iter_, other.iter_);
return (row_idx_ == other.row_idx_);

View File

@ -45,4 +45,5 @@ ExternalProject_Add(re2
endif()
-DCMAKE_BUILD_TYPE:STRING=Release
-DCMAKE_INSTALL_PREFIX:STRING=${re2_INSTALL}
-DRE2_BUILD_TESTING:BOOL=OFF
)

View File

@ -95,10 +95,18 @@ if(WIN32)
add_dependencies(tensorflow tensorflow_static)
endif(WIN32)
install(TARGETS tensorflow
target_include_directories(tensorflow PUBLIC
$<INSTALL_INTERFACE:include/>
$<INSTALL_INTERFACE:include/external/nsync/public>)
install(TARGETS tensorflow EXPORT tensorflow_export
RUNTIME DESTINATION bin
LIBRARY DESTINATION lib
ARCHIVE DESTINATION lib)
install(EXPORT tensorflow_export
FILE TensorflowConfig.cmake
DESTINATION lib/cmake)
# install necessary headers
# tensorflow headers

View File

@ -153,7 +153,7 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/contrib/data/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/image/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/keras/python/keras/integration_test.py"
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/nearest_neighbor/python/kernel_tests/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/seq2seq/python/kernel_tests/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/stateless/python/kernel_tests/*_test.py"
@ -171,7 +171,6 @@ if (tensorflow_BUILD_PYTHON_TESTS)
"${tensorflow_source_dir}/tensorflow/contrib/graph_editor/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/bayesflow/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/framework/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/keras/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/distributions/*_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/learn/*_test.py"
)
@ -225,6 +224,9 @@ if (tensorflow_BUILD_PYTHON_TESTS)
# Numerical issues, calculations off.
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/concat_op_test.py"
"${tensorflow_source_dir}/tensorflow/contrib/factorization/python/ops/wals_test.py"
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/utils/data_utils_test.py"
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/backend_test.py"
"${tensorflow_source_dir}/tensorflow/python/keras/_impl/keras/preprocessing/image_test.py"
# Float division by zero
"${tensorflow_source_dir}/tensorflow/python/kernel_tests/benchmark_test.py"
# Flaky, for unknown reasons. Cannot reproduce in terminal. Revisit once we can get stack traces.

View File

@ -420,7 +420,7 @@ class CrfDecodeBackwardRnnCell(rnn_cell.RNNCell):
"""Initialize the CrfDecodeBackwardRnnCell.
Args:
num_tags: The number of tags.
num_tags: An integer. The number of tags.
"""
self._num_tags = num_tags

View File

@ -161,6 +161,7 @@ py_test(
size = "small",
srcs = ["flat_map_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
@ -278,6 +279,7 @@ py_test(
size = "medium",
srcs = ["map_dataset_op_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:dataset_ops",
@ -348,6 +350,7 @@ py_test(
size = "medium",
srcs = ["reader_dataset_ops_test.py"],
srcs_version = "PY2AND3",
tags = ["no_pip"],
deps = [
":dataset_serialization_test",
"//tensorflow/contrib/data/python/ops:readers",

View File

@ -293,7 +293,7 @@ class PoissonLogNormalQuadratureCompound(distribution_lib.Distribution):
# where,
#
# Z|v ~ interpolate_affine[v](distribution)
# V ~ mixture_distrubution
# V ~ mixture_distribution
#
# thus,
#

View File

@ -73,7 +73,7 @@ class Metric(object):
* `result()`: Computes and returns a final value for the metric
from the variables in `self`.
Decendants may override `aggregate()`, but usually won't need to. It
Descendants may override `aggregate()`, but usually won't need to. It
adds in the state from a list of metrics of the same type as `self`.
(Default is to sum all the variables.) Note that users should not call
`aggregate()`, it is for use by TensorFlow infrastructure.

View File

@ -183,7 +183,7 @@ def _wals_factorization_model_function(features, labels, mode, params):
# TRAIN mode:
if mode == model_fn.ModeKeys.TRAIN:
# Training consists of the folowing ops (controlled using a SweepHook).
# Training consists of the following ops (controlled using a SweepHook).
# Before a row sweep:
# row_update_prep_gramian_op
# initialize_row_update_op

View File

@ -47,10 +47,25 @@ cc_library(
alwayslink = 1,
)
cc_library(
name = "decode_video_op_cc",
srcs = ["decode_video_op.cc"],
copts = tf_copts(),
linkstatic = 1,
visibility = ["//visibility:private"],
deps = [
"//tensorflow/contrib/ffmpeg/default:ffmpeg_lib",
"//tensorflow/core:framework_headers_lib",
"//third_party/eigen3",
],
alwayslink = 1,
)
tf_custom_op_library(
name = "ffmpeg.so",
deps = [
":decode_audio_op_cc",
":decode_video_op_cc",
":encode_audio_op_cc",
],
)
@ -59,6 +74,7 @@ cc_library(
name = "ffmpeg_op_lib",
deps = [
":decode_audio_op_cc",
":decode_video_op_cc",
":encode_audio_op_cc",
],
)
@ -81,6 +97,15 @@ tf_gen_op_wrapper_py(
],
)
tf_gen_op_wrapper_py(
name = "decode_video_op_py",
require_shape_functions = True,
visibility = ["//visibility:private"],
deps = [
":decode_video_op_cc",
],
)
tf_py_test(
name = "decode_audio_op_test",
srcs = ["decode_audio_op_test.py"],
@ -115,6 +140,27 @@ tf_py_test(
tags = ["manual"],
)
tf_py_test(
name = "decode_video_op_test",
size = "small",
srcs = ["decode_video_op_test.py"],
additional_deps = [
":ffmpeg_ops_py",
"@six_archive//:six",
"//tensorflow/python:array_ops",
"//tensorflow/python:client_testlib",
"//tensorflow/python:platform",
"//tensorflow/python:image_ops",
],
data = [
":test_data",
],
tags = [
"manual",
"notap",
],
)
py_library(
name = "ffmpeg_ops_py",
srcs = [
@ -126,6 +172,7 @@ py_library(
visibility = ["//visibility:public"],
deps = [
":decode_audio_op_py",
":decode_video_op_py",
":encode_audio_op_py",
"//tensorflow/contrib/util:util_py",
"//tensorflow/python:framework_for_generated_wrappers",

View File

@ -26,9 +26,10 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_audio
from tensorflow.contrib.ffmpeg.ffmpeg_ops import decode_video
from tensorflow.contrib.ffmpeg.ffmpeg_ops import encode_audio
from tensorflow.python.util.all_util import remove_undocumented
_allowed_symbols = ['decode_audio', 'encode_audio']
_allowed_symbols = ['decode_audio', 'encode_audio', 'decode_video']
remove_undocumented(__name__, _allowed_symbols)

View File

@ -37,29 +37,6 @@ namespace {
// https://www.ffmpeg.org/ffmpeg-formats.html
const char* kValidFileFormats[] = {"mp3", "mp4", "ogg", "wav"};
// Writes binary data to a file.
Status WriteFile(const string& filename, tensorflow::StringPiece contents) {
Env& env = *Env::Default();
std::unique_ptr<WritableFile> file;
TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file));
TF_RETURN_IF_ERROR(file->Append(contents));
TF_RETURN_IF_ERROR(file->Close());
return Status::OK();
}
// Cleans up a file on destruction.
class FileDeleter {
public:
explicit FileDeleter(const string& filename) : filename_(filename) {}
~FileDeleter() {
Env& env = *Env::Default();
env.DeleteFile(filename_).IgnoreError();
}
private:
const string filename_;
};
/*
* Decoding implementation, shared across V1 and V2 ops. Creates a new
* output in the context.
@ -69,7 +46,7 @@ void Decode(OpKernelContext* context,
const string& file_format, const int32 samples_per_second,
const int32 channel_count) {
// Write the input data to a temp file.
const string temp_filename = GetTempFilename(file_format);
const string temp_filename = io::GetTempFilename(file_format);
OP_REQUIRES_OK(context, WriteFile(temp_filename, file_contents));
FileDeleter deleter(temp_filename);

View File

@ -0,0 +1,118 @@
// Copyright 2016 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
// =============================================================================
#include <stdlib.h>
#include <cstdio>
#include <set>
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/shape_inference.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/logging.h"
namespace tensorflow {
namespace ffmpeg {
class DecodeVideoOp : public OpKernel {
public:
explicit DecodeVideoOp(OpKernelConstruction* context) : OpKernel(context) {}
void Compute(OpKernelContext* context) override {
OP_REQUIRES(
context, context->num_inputs() == 1,
errors::InvalidArgument("DecodeVideo requires exactly 1 input."));
const Tensor& contents_tensor = context->input(0);
OP_REQUIRES(context, TensorShapeUtils::IsScalar(contents_tensor.shape()),
errors::InvalidArgument(
"contents must be a rank-0 tensor but got shape ",
contents_tensor.shape().DebugString()));
const tensorflow::StringPiece contents = contents_tensor.scalar<string>()();
// Write the input data to a temp file.
string extension;
const string temp_filename = io::GetTempFilename(extension);
OP_REQUIRES_OK(context, WriteFile(temp_filename, contents));
FileDeleter deleter(temp_filename);
uint32 width = 0;
uint32 height = 0;
uint32 frames = 0;
// Run FFmpeg on the data and verify results.
std::vector<uint8> output_data;
const Status result = ffmpeg::ReadVideoFile(temp_filename, &output_data,
&width, &height, &frames);
if (result.code() == error::Code::NOT_FOUND) {
OP_REQUIRES(
context, result.ok(),
errors::Unavailable("FFmpeg must be installed to run this op. FFmpeg "
"can be found at http://www.ffmpeg.org."));
} else if (result.code() == error::UNKNOWN) {
LOG(ERROR) << "Ffmpeg failed with error '" << result.error_message()
<< "'. Returning empty tensor.";
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(0, TensorShape({0, 0}), &output));
return;
} else {
OP_REQUIRES_OK(context, result);
}
OP_REQUIRES(context, !output_data.empty(),
errors::Unknown("No output created by FFmpeg."));
OP_REQUIRES(
context, output_data.size() == (frames * height * width * 3),
errors::Unknown("Output created by FFmpeg [", output_data.size(),
"] does not match description [", frames, ", ", height,
", ", width, ", 3]"));
Tensor* output = nullptr;
OP_REQUIRES_OK(context,
context->allocate_output(
0, TensorShape({frames, height, width, 3}), &output));
auto output_flat = output->flat<uint8>();
std::copy_n(output_data.begin(), output_data.size(), &output_flat(0));
}
};
REGISTER_KERNEL_BUILDER(Name("DecodeVideo").Device(DEVICE_CPU), DecodeVideoOp);
REGISTER_OP("DecodeVideo")
.Input("contents: string")
.Output("output: uint8")
.SetShapeFn([](shape_inference::InferenceContext* c) {
c->set_output(0, c->UnknownShapeOfRank(4));
return Status::OK();
})
.Doc(R"doc(
Processes the contents of a video file into a tensor using FFmpeg to decode
the file.
The decoded frames are returned as a rank-4 tensor with shape
`[frames, height, width, 3]` containing RGB24 pixel data.
contents: The binary video file contents, as a string or rank-0 string
tensor.
)doc");
} // namespace ffmpeg
} // namespace tensorflow

View File

@ -0,0 +1,69 @@
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
"""Tests for third_party.tensorflow.contrib.ffmpeg.decode_video_op."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import os.path
import six # pylint: disable=unused-import
from tensorflow.contrib import ffmpeg
from tensorflow.python.ops import image_ops
from tensorflow.python.platform import resource_loader
from tensorflow.python.platform import test
class DecodeVideoOpTest(test.TestCase):
def _loadFileAndTest(self, filename, width, height, frames, bmp_filename,
index):
"""Loads an video file and validates the output tensor.
Args:
filename: The filename of the input file.
width: The width of the video.
height: The height of the video.
frames: The frames of the video.
bmp_filename: The filename for the bmp file.
index: Index location inside the video.
"""
with self.test_session():
path = os.path.join(resource_loader.get_data_files_path(), 'testdata',
filename)
with open(path, 'rb') as f:
contents = f.read()
bmp_path = os.path.join(resource_loader.get_data_files_path(), 'testdata',
bmp_filename)
with open(bmp_path, 'rb') as f:
bmp_contents = f.read()
image_op = image_ops.decode_bmp(bmp_contents)
image = image_op.eval()
self.assertEqual(image.shape, (height, width, 3))
video_op = ffmpeg.decode_video(contents)
video = video_op.eval()
self.assertEqual(video.shape, (frames, height, width, 3))
self.assertAllEqual(video[index, :, :, :], image)
def testMp4(self):
self._loadFileAndTest('small.mp4', 560, 320, 166, 'small_100.bmp', 99)
if __name__ == '__main__':
test.main()

View File

@ -16,6 +16,7 @@
#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include <errno.h>
#include <fcntl.h>
#include <stdlib.h>
#include <sys/stat.h>
#include <sys/types.h>
@ -25,6 +26,7 @@
#include <vector>
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/lib/strings/numbers.h"
#include "tensorflow/core/lib/strings/str_util.h"
#include "tensorflow/core/platform/cpu_info.h"
#include "tensorflow/core/platform/env.h"
@ -38,28 +40,45 @@ namespace {
const char kFfmpegExecutable[] = "ffmpeg";
const int32 kDefaultProbeSize = 5000000; // 5MB
std::vector<string> FfmpegCommandLine(const string& input_filename,
const string& output_filename,
const string& input_format_id,
int32 samples_per_second,
int32 channel_count) {
return {
"-nostats", // No additional progress display.
"-nostdin", // No interactive commands accepted.
"-f", input_format_id, // eg: "mp3"
"-probesize", StrCat(kDefaultProbeSize),
"-i", input_filename,
"-loglevel", "info", // Enable verbose logging to support debugging.
"-map_metadata", "-1", // Copy global metadata from input to output.
"-vn", // No video recording.
"-ac:a:0", StrCat(channel_count),
"-ar:a:0", StrCat(samples_per_second),
// Output set (in several ways) to signed 16-bit little-endian ints.
"-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le",
"-sn", // No subtitle recording.
"-y", // Overwrite output file.
StrCat(output_filename)
};
std::vector<string> FfmpegAudioCommandLine(const string& input_filename,
const string& output_filename,
const string& input_format_id,
int32 samples_per_second,
int32 channel_count) {
return {"-nostats", // No additional progress display.
"-nostdin", // No interactive commands accepted.
"-f", input_format_id, // eg: "mp3"
"-probesize", StrCat(kDefaultProbeSize), "-i", input_filename,
"-loglevel", "info", // Enable verbose logging to support debugging.
"-map_metadata", "-1", // Copy global metadata from input to output.
"-vn", // No video recording.
"-ac:a:0", StrCat(channel_count), "-ar:a:0",
StrCat(samples_per_second),
// Output set (in several ways) to signed 16-bit little-endian ints.
"-codec:a:0", "pcm_s16le", "-sample_fmt", "s16", "-f", "s16le",
"-sn", // No subtitle recording.
"-y", // Overwrite output file.
StrCat(output_filename)};
}
std::vector<string> FfmpegVideoCommandLine(const string& input_filename,
const string& output_filename) {
return {"-nostats", // No additional progress display.
"-nostdin", // No interactive commands accepted.
"-i",
input_filename,
"-f",
"image2pipe",
"-probesize",
StrCat(kDefaultProbeSize),
"-loglevel",
"info", // Enable verbose logging to support debugging.
"-vcodec",
"rawvideo",
"-pix_fmt",
"rgb24",
"-y", // Overwrite output file.
StrCat(output_filename)};
}
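Assembled, the argument vector above corresponds to an FFmpeg invocation along these lines (filenames hypothetical; the probe size comes from kDefaultProbeSize):

ffmpeg -nostats -nostdin -i input.mp4 -f image2pipe -probesize 5000000 -loglevel info -vcodec rawvideo -pix_fmt rgb24 -y output.raw

i.e. decode the input video and write raw, packed RGB24 frames to the output file.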
// Is a named binary installed and executable by the current process?
@ -106,7 +125,7 @@ bool IsBinaryInstalled(const string& binary_name) {
::execvp(kFfmpegExecutable, args_chars.data());
// exec only returns on error.
const int error = errno;
LOG(ERROR) << "FFmpeg could not be executed: " << error;
LOG(ERROR) << "FFmpeg could not be executed: " << strerror(error);
::_exit(error);
}
@ -198,52 +217,101 @@ string BuildWavFile(int32 samples_per_second, int32 channel_count,
return data;
}
// Returns a unique number every time it is called.
int64 UniqueId() {
static mutex mu(LINKER_INITIALIZED);
static int64 id = 0;
mutex_lock l(mu);
return ++id;
Status ReadInfoFile(const string& filename, uint32* width, uint32* height,
uint32* frames) {
string data;
TF_QCHECK_OK(ReadFileToString(Env::Default(), filename, &data))
<< "Could not read FFmpeg file: " << filename;
bool in_output = false;
bool in_mapping = false;
uint32 frames_value = 0;
uint32 height_value = 0;
uint32 width_value = 0;
for (const string& line : str_util::Split(data, '\n')) {
// The output section starts with the first line beginning `Output #..`.
// Processing of the output region starts on the next line, so we can
// continue the loop here.
if (!in_output && line.find("Output #") == 0) {
in_output = true;
in_mapping = false;
continue;
}
// The stream mapping section starts with the first line beginning
// `Stream mapping:`; it also marks the end of the output section.
// Processing of the stream mapping region starts on the next line, so we
// can continue the loop here.
if (!in_mapping && line.find("Stream mapping:") == 0) {
in_output = false;
in_mapping = true;
continue;
}
if (in_output) {
// We only look for the first stream in the output, `Stream #0`.
// Once it is processed, we stop processing the output section.
if (line.find(" Stream #") == 0) {
size_t p = line.find(", rgb24, ", 24);
if (p != std::string::npos) {
string rgb24 = line.substr(p + 9, line.find(" ", p + 9));
rgb24 = rgb24.substr(0, rgb24.find(","));
string rgb24_width = rgb24.substr(0, rgb24.find("x"));
string rgb24_height = rgb24.substr(rgb24_width.length() + 1);
if (strings::safe_strtou32(rgb24_width, &width_value) &&
strings::safe_strtou32(rgb24_height, &height_value)) {
in_output = false;
}
}
}
continue;
}
if (in_mapping) {
// We only look at the first stream mapping to obtain the number of frames.
// Once it is processed, we stop processing the stream mapping section.
if (line.find("frame= ") == 0) {
string number = line.substr(8, line.find(" ", 8));
number = number.substr(0, number.find(" "));
if (strings::safe_strtou32(number, &frames_value)) {
in_mapping = false;
}
}
continue;
}
}
if (frames_value == 0 || height_value == 0 || width_value == 0) {
return errors::Unknown("Not enough video info returned by FFmpeg [",
frames_value, ", ", height_value, ", ", width_value,
", 3]");
}
*width = width_value;
*height = height_value;
*frames = frames_value;
return Status::OK();
}
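To make the parsing above concrete, this is the shape of stderr text the function accepts. The lines are illustrative, derived from the string matching in the code (and from the 560x320, 166-frame test video used later in this commit) rather than captured from a real FFmpeg run:

Output #0, image2pipe, to '/tmp/out.raw':
    Stream #0:0: Video: rawvideo, rgb24, 560x320, ...
Stream mapping:
frame=  166 ...

From such input the parser extracts width=560 and height=320 from the `rgb24, WxH` token of the first output stream, and frames=166 from the first `frame=` line after `Stream mapping:`.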
} // namespace
string GetTempFilename(const string& extension) {
for (const char* dir : std::vector<const char*>(
{getenv("TEST_TMPDIR"), getenv("TMPDIR"), getenv("TMP"), "/tmp"})) {
if (!dir || !dir[0]) {
continue;
}
struct stat statbuf;
if (!stat(dir, &statbuf) && S_ISDIR(statbuf.st_mode)) {
// UniqueId is added here because mkstemps is not as thread safe as it
// looks. https://github.com/tensorflow/tensorflow/issues/5804 shows
// the problem.
string tmp_filepath = io::JoinPath(
dir,
StrCat("tmp_file_tensorflow_", UniqueId(), "_XXXXXX.", extension));
int fd = mkstemps(&tmp_filepath[0], extension.length() + 1);
if (fd < 0) {
LOG(FATAL) << "Failed to create temp file.";
} else {
close(fd);
return tmp_filepath;
}
}
}
LOG(FATAL) << "No temp directory found.";
FileDeleter::~FileDeleter() {
Env& env = *Env::Default();
env.DeleteFile(filename_).IgnoreError();
}
Status ReadAudioFile(const string& filename,
const string& audio_format_id,
int32 samples_per_second,
int32 channel_count,
Status WriteFile(const string& filename, StringPiece contents) {
Env& env = *Env::Default();
std::unique_ptr<WritableFile> file;
TF_RETURN_IF_ERROR(env.NewWritableFile(filename, &file));
TF_RETURN_IF_ERROR(file->Append(contents));
TF_RETURN_IF_ERROR(file->Close());
return Status::OK();
}
Status ReadAudioFile(const string& filename, const string& audio_format_id,
int32 samples_per_second, int32 channel_count,
std::vector<float>* output_samples) {
// Create an argument list.
string output_filename = GetTempFilename("raw");
string output_filename = io::GetTempFilename("raw");
const std::vector<string> args =
FfmpegCommandLine(filename, output_filename, audio_format_id,
samples_per_second, channel_count);
FfmpegAudioCommandLine(filename, output_filename, audio_format_id,
samples_per_second, channel_count);
// Unfortunately, it's impossible to differentiate an exec failure due to the
// binary being missing and an error from the binary's execution. Therefore,
@ -256,7 +324,8 @@ Status ReadAudioFile(const string& filename,
// Execute ffmpeg and report errors.
pid_t child_pid = ::fork();
if (child_pid < 0) {
return Status(error::Code::UNKNOWN, StrCat("fork failed: ", errno));
return Status(error::Code::UNKNOWN,
StrCat("fork failed: ", strerror(errno)));
}
if (child_pid == 0) {
ExecuteFfmpeg(args);
@ -285,5 +354,63 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second,
return Status::OK();
}
Status ReadVideoFile(const string& filename, std::vector<uint8>* output_data,
uint32* width, uint32* height, uint32* frames) {
if (!IsBinaryInstalled(kFfmpegExecutable)) {
return Status(error::Code::NOT_FOUND, StrCat("FFmpeg could not be found."));
}
string output_filename = io::GetTempFilename("raw");
string stderr_filename = io::GetTempFilename("err");
// Create an argument list.
const std::vector<string> args =
FfmpegVideoCommandLine(filename, output_filename);
// Execute ffmpeg and report errors.
pid_t child_pid = ::fork();
if (child_pid < 0) {
return Status(error::Code::UNKNOWN,
StrCat("fork failed: ", strerror(errno)));
}
if (child_pid == 0) {
const int fd =
open(stderr_filename.c_str(), O_RDWR | O_CREAT | O_APPEND, 0600);
if (fd < 0) {
const int error = errno;
LOG(ERROR) << "FFmpeg stderr file could not be created: "
<< strerror(error);
::_exit(error);
}
close(STDERR_FILENO);
dup2(fd, STDERR_FILENO);
ExecuteFfmpeg(args);
} else {
int status_code;
if (::waitpid(child_pid, &status_code, 0) < 0) {
return Status(error::Code::UNKNOWN,
StrCat("waitpid failed: ", strerror(errno)));
}
if (status_code) {
return Status(error::Code::UNKNOWN,
StrCat("FFmpeg execution failed: ", status_code));
}
TF_QCHECK_OK(ReadInfoFile(stderr_filename, width, height, frames))
<< "Could not read FFmpeg stderr file: " << stderr_filename;
string raw_data;
TF_QCHECK_OK(ReadFileToString(Env::Default(), output_filename, &raw_data))
<< "Could not read FFmpeg output file: " << output_filename;
output_data->resize(raw_data.size());
std::copy_n(raw_data.data(), raw_data.size(), output_data->begin());
TF_QCHECK_OK(Env::Default()->DeleteFile(output_filename))
<< output_filename;
TF_QCHECK_OK(Env::Default()->DeleteFile(stderr_filename))
<< stderr_filename;
return Status::OK();
}
}
} // namespace ffmpeg
} // namespace tensorflow

View File

@ -21,6 +21,7 @@
#include <vector>
#include "tensorflow/core/lib/core/threadpool.h"
#include "tensorflow/core/lib/io/path.h"
#include "tensorflow/core/platform/env.h"
#include "tensorflow/core/platform/mutex.h"
#include "tensorflow/core/platform/test.h"
@ -49,7 +50,7 @@ TEST(FfmpegLibTest, TestTempDirectoryThreading) {
pool.Schedule([&mu, &temp_filenames, environment]() {
std::array<string, kStringsPerItem> buffer;
for (int32 j = 0; j < kStringsPerItem; ++j) {
buffer[j] = GetTempFilename("mp3");
buffer[j] = io::GetTempFilename("mp3");
TF_QCHECK_OK(environment->DeleteFile(buffer[j]));
}
mutex_lock l(mu);

View File

@ -24,16 +24,24 @@
namespace tensorflow {
namespace ffmpeg {
// Gets a temp filename in an appropriate location.
string GetTempFilename(const string& extension);
// Cleans up a file on destruction.
class FileDeleter {
public:
explicit FileDeleter(const string& filename) : filename_(filename) {}
~FileDeleter();
private:
const string filename_;
};
// Writes binary data to a file.
Status WriteFile(const string& filename, tensorflow::StringPiece contents);
// Reads an audio file using ffmpeg and converts it into an array of samples in
// [-1.0, 1.0]. If there are multiple channels in the audio then each frame will
// contain a separate sample for each channel. Frames are ordered by time.
Status ReadAudioFile(const string& filename,
const string& audio_format_id,
int32 samples_per_second,
int32 channel_count,
Status ReadAudioFile(const string& filename, const string& audio_format_id,
int32 samples_per_second, int32 channel_count,
std::vector<float>* output_samples);
// Creates an audio file using ffmpeg in a specific format. The samples are in
@ -45,6 +53,11 @@ Status CreateAudioFile(const string& audio_format_id, int32 bits_per_second,
int32 samples_per_second, int32 channel_count,
const std::vector<float>& samples, string* output_data);
// Reads a video file using ffmpeg and converts it into RGB24 frames in a
// uint8 [frames, height, width, 3] layout. The width, height, and frame
// count are obtained from ffmpeg.
Status ReadVideoFile(const string& filename, std::vector<uint8>* output_data,
uint32* width, uint32* height, uint32* frames);
} // namespace ffmpeg
} // namespace tensorflow
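A minimal caller sketch for the audio and video entry points declared above; the file paths are hypothetical and error handling is reduced to TF_CHECK_OK:

#include <vector>

#include "tensorflow/contrib/ffmpeg/ffmpeg_lib.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/platform/types.h"

void ReadMediaExample() {
  // Audio: interleaved float samples in [-1.0, 1.0]; two channels at 44.1kHz.
  std::vector<float> samples;
  TF_CHECK_OK(tensorflow::ffmpeg::ReadAudioFile("/tmp/clip.mp3", "mp3",
                                                44100, 2, &samples));

  // Video: packed RGB24 bytes; width/height/frames are reported by FFmpeg.
  std::vector<tensorflow::uint8> rgb;
  tensorflow::uint32 width = 0, height = 0, frames = 0;
  TF_CHECK_OK(tensorflow::ffmpeg::ReadVideoFile("/tmp/clip.mp4", &rgb,
                                                &width, &height, &frames));
  // rgb now holds frames * height * width * 3 bytes, frame-major.
}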

View File

@ -19,6 +19,7 @@ from __future__ import division
from __future__ import print_function
from tensorflow.contrib.ffmpeg.ops import gen_decode_audio_op_py
from tensorflow.contrib.ffmpeg.ops import gen_decode_video_op_py
from tensorflow.contrib.ffmpeg.ops import gen_encode_audio_op_py
from tensorflow.contrib.util import loader
from tensorflow.python.framework import ops
@ -89,3 +90,19 @@ def encode_audio(audio, file_format=None, samples_per_second=None):
ops.NotDifferentiable('EncodeAudio')
def decode_video(contents):
"""Create an op that decodes the contents of a video file.
Args:
contents: The binary contents of the video file to decode. This is a
scalar.
Returns:
A rank-4 `Tensor` with shape `[frames, height, width, 3]` containing RGB
frames.
"""
return gen_decode_video_op_py.decode_video(contents)
ops.NotDifferentiable('DecodeVideo')

Binary file not shown.

Binary file not shown.

View File

@ -24,12 +24,14 @@ import six
# pylint: disable=unused-import
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import ops
from tensorflow.python.framework.graph_util_impl import _assert_nodes_are_present
from tensorflow.python.framework.graph_util_impl import _bfs_for_reachable_nodes
from tensorflow.python.framework.graph_util_impl import _extract_graph_summary
from tensorflow.python.framework.graph_util_impl import _node_name
__all__ = ["fuse_op"]
__all__ = ["fuse_op", "get_placeholders"]
def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
@ -126,3 +128,27 @@ def fuse_op(graph_def, input_nodes, output_nodes, output_dtypes,
out.library.CopyFrom(graph_def.library)
out.versions.CopyFrom(graph_def.versions)
return out
def get_placeholders(graph):
"""Get placeholders of a graph.
Args:
graph: A tf.Graph.
Returns:
A list containing all placeholders of the given graph.
Raises:
TypeError: If `graph` is not a tensorflow graph.
"""
if not isinstance(graph, ops.Graph):
raise TypeError("Input graph needs to be a Graph: %s" % graph)
# For each placeholder() call, there is a corresponding
# operation of type 'Placeholder' registered to the graph.
# The return value (a Tensor) of placeholder() is in fact
# the first output of this operation.
operations = graph.get_operations()
result = [i.outputs[0] for i in operations if i.type == "Placeholder"]
return result

View File

@ -21,6 +21,9 @@ from tensorflow.contrib.framework.python.framework import graph_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
from tensorflow.core.framework import types_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.platform import test
@ -81,5 +84,16 @@ class GraphUtilTest(test.TestCase):
self.assertEqual(fused_graph_def.node[4].name, 'E')
class GetPlaceholdersTest(test.TestCase):
def test_get_placeholders(self):
with ops.Graph().as_default() as g:
placeholders = [array_ops.placeholder(dtypes.float32) for _ in range(5)]
results = graph_util.get_placeholders(g)
self.assertEqual(
sorted(placeholders, key=lambda x: x._id), # pylint: disable=protected-access
sorted(results, key=lambda x: x._id)) # pylint: disable=protected-access
if __name__ == '__main__':
test.main()

View File

@ -422,7 +422,7 @@ def gan_loss(
ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
model, add_summaries=add_summaries)
dis_loss += aux_cond_discriminator_weight * ac_disc_loss
# Gathers auxilliary losses.
# Gathers auxiliary losses.
if model.generator_scope:
gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
else:

View File

@ -2561,7 +2561,10 @@ def separable_convolution2d(
regularizer=weights_regularizer,
trainable=trainable,
collections=weights_collections)
strides = [1, stride_h, stride_w, 1]
strides = [1, 1, stride_h,
stride_w] if data_format.startswith('NC') else [
1, stride_h, stride_w, 1
]
outputs = nn.depthwise_conv2d(inputs, depthwise_weights, strides, padding,
rate=utils.two_element_tuple(rate),

View File

@ -3326,16 +3326,24 @@ class SeparableConv2dTest(test.TestCase):
for model_variable in model_variables:
self.assertEqual(trainable, model_variable in trainable_variables)
def testConvNCHW(self):
for num_filters, correct_output_filters in [(None, 6), (8, 8)]:
def testSepConvNCHW(self):
for num_filters, correct_output_filters in zip((None, 5), (6, 5)):
with self.test_session():
batch, height, width = 4, 5, 6
batch, height, width = 4, 10, 12
kernel_dim, stride = 3, 2
images = random_ops.random_uniform((batch, 3, height, width), seed=1)
output = layers_lib.separable_conv2d(
images, num_filters, [3, 3], 2, padding='VALID', data_format='NCHW')
self.assertListEqual(
output.get_shape().as_list(), [batch, correct_output_filters,
height - 2, width - 2])
images,
num_outputs=num_filters,
kernel_size=[kernel_dim, kernel_dim],
depth_multiplier=2,
stride=stride,
padding='VALID',
data_format='NCHW')
self.assertListEqual(output.get_shape().as_list(), [
batch, correct_output_filters, (height - kernel_dim + 1) // stride,
(width - kernel_dim + 1) // stride
])
class ScaleGradientTests(test.TestCase):

View File

@ -191,6 +191,9 @@ filegroup(
exclude = [
"**/METADATA",
"**/OWNERS",
"downloads",
"examples",
"gen",
],
),
visibility = ["//tensorflow:__subpackages__"],

View File

@ -0,0 +1,147 @@
# Find where we're running from, so we can store generated files here.
ifeq ($(origin MAKEFILE_DIR), undefined)
MAKEFILE_DIR := $(shell dirname $(realpath $(lastword $(MAKEFILE_LIST))))
endif
# Try to figure out the host system
HOST_OS :=
ifeq ($(OS),Windows_NT)
HOST_OS = WINDOWS
else
UNAME_S := $(shell uname -s)
ifeq ($(UNAME_S),Linux)
HOST_OS := LINUX
endif
ifeq ($(UNAME_S),Darwin)
HOST_OS := OSX
endif
endif
ARCH := $(shell if [[ $(shell uname -m) =~ i[345678]86 ]]; then echo x86_32; else echo $(shell uname -m); fi)
# Where compiled objects are stored.
OBJDIR := $(MAKEFILE_DIR)/gen/obj/
BINDIR := $(MAKEFILE_DIR)/gen/bin/
LIBDIR := $(MAKEFILE_DIR)/gen/lib/
GENDIR := $(MAKEFILE_DIR)/gen/obj/
# Settings for the host compiler.
CXX := $(CC_PREFIX) gcc
CXXFLAGS := --std=c++11 -O3 -DNDEBUG
CC := $(CC_PREFIX) gcc
CFLAGS :=
LDOPTS :=
LDOPTS += -L/usr/local/lib
ARFLAGS := -r
INCLUDES := \
-I. \
-I$(MAKEFILE_DIR)/../../../ \
-I$(MAKEFILE_DIR)/downloads/ \
-I$(MAKEFILE_DIR)/downloads/eigen \
-I$(MAKEFILE_DIR)/downloads/gemmlowp \
-I$(MAKEFILE_DIR)/downloads/neon_2_sse \
-I$(MAKEFILE_DIR)/downloads/farmhash/src \
-I$(MAKEFILE_DIR)/downloads/flatbuffers/include \
-I$(GENDIR)
# This is at the end so any globally-installed frameworks like protobuf don't
# override local versions in the source tree.
INCLUDES += -I/usr/local/include
LIBS := \
-lstdc++ \
-lpthread \
-lm \
-lz
# If we're on Linux, also link in the dl library.
ifeq ($(HOST_OS),LINUX)
LIBS += -ldl -lpthread
endif
include $(MAKEFILE_DIR)/ios_makefile.inc
# This library is the main target for this makefile. It will contain a minimal
# runtime that can be linked in to other programs.
LIB_NAME := libtensorflow-lite.a
LIB_PATH := $(LIBDIR)$(LIB_NAME)
# A small example program that shows how to link against the library.
BENCHMARK_PATH := $(BINDIR)benchmark_model
BENCHMARK_SRCS := \
tensorflow/contrib/lite/tools/benchmark_model.cc
BENCHMARK_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(BENCHMARK_SRCS))))
# What sources we want to compile, must be kept in sync with the main Bazel
# build files.
CORE_CC_ALL_SRCS := \
$(wildcard tensorflow/contrib/lite/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.cc) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.cc) \
$(wildcard tensorflow/contrib/lite/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/optimized/*.c) \
$(wildcard tensorflow/contrib/lite/kernels/internal/reference/*.c) \
$(wildcard tensorflow/contrib/lite/downloads/farmhash/src/farmhash.cc)
# Remove any duplicates.
CORE_CC_ALL_SRCS := $(sort $(CORE_CC_ALL_SRCS))
CORE_CC_EXCLUDE_SRCS := \
$(wildcard tensorflow/contrib/lite/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/*/*/*/*test.cc) \
$(wildcard tensorflow/contrib/lite/kernels/test_util.cc) \
$(BENCHMARK_SRCS)
# Filter out all the excluded files.
TF_LITE_CC_SRCS := $(filter-out $(CORE_CC_EXCLUDE_SRCS), $(CORE_CC_ALL_SRCS))
# File names of the intermediate files target compilation generates.
TF_LITE_CC_OBJS := $(addprefix $(OBJDIR), \
$(patsubst %.cc,%.o,$(patsubst %.c,%.o,$(TF_LITE_CC_SRCS))))
LIB_OBJS := $(TF_LITE_CC_OBJS)
# For normal manually-created TensorFlow C++ source files.
$(OBJDIR)%.o: %.cc
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) -c $< -o $@
# For normal manually-created TensorFlow C source files.
$(OBJDIR)%.o: %.c
@mkdir -p $(dir $@)
$(CC) $(CFLAGS) $(INCLUDES) -c $< -o $@
# The target that's compiled if there are no command-line arguments.
all: $(LIB_PATH) $(BENCHMARK_PATH)
# Gathers together all the objects we've compiled into a single '.a' archive.
$(LIB_PATH): $(LIB_OBJS)
@mkdir -p $(dir $@)
$(AR) $(ARFLAGS) $(LIB_PATH) $(LIB_OBJS)
$(BENCHMARK_PATH): $(BENCHMARK_OBJS) $(LIB_PATH)
@mkdir -p $(dir $@)
$(CXX) $(CXXFLAGS) $(INCLUDES) \
-o $(BENCHMARK_PATH) $(BENCHMARK_OBJS) \
$(LIBFLAGS) $(LIB_PATH) $(LDFLAGS) $(LIBS)
# Gets rid of all generated files.
clean:
rm -rf $(MAKEFILE_DIR)/gen
# Gets rid of target files only, leaving the host alone. Also leaves the lib
# directory untouched deliberately, so we can persist multiple architectures
# across builds for iOS and Android.
cleantarget:
rm -rf $(OBJDIR)
rm -rf $(BINDIR)
$(DEPDIR)/%.d: ;
.PRECIOUS: $(DEPDIR)/%.d
-include $(patsubst %,$(DEPDIR)/%.d,$(basename $(TF_LITE_CC_SRCS)))

View File

@ -1,5 +1,5 @@
# TensorFlow Lite
TensorFlow Lite is TensorFlow’s lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration.
TensorFlow Lite is TensorFlow's lightweight solution for mobile and embedded devices. It enables low-latency inference of on-device machine learning models with a small binary size and fast performance supporting hardware acceleration.
TensorFlow Lite uses many techniques for achieving low latency like optimizing the kernels for specific mobile apps, pre-fused activations, quantized kernels that allow smaller and faster (fixed-point math) models, and in the future, leverage specialized machine learning hardware to get the best possible performance for a particular model on a particular device.
@ -20,18 +20,18 @@ In the demo app, inference is done using the TensorFlow Lite Java API. The demo
The fastest path to trying the demo is to download the pre-built binary
[TfLiteCameraDemo.apk](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk)
Once the apk is installed, click the app icon to start the app. The first time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera’s field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified.
Once the apk is installed, click the app icon to start the app. The first time the app is opened, the app asks for runtime permissions to access the device camera. The demo app opens the back-camera of the device and recognizes the objects in the camera's field of view. At the bottom of the image (or at the left of the image if the device is in landscape mode), it shows the latency of classification and the top three objects classified.
## Building in Android Studio using TensorFlow Lite AAR from JCenter
The simplest way to compile the demo app and try out changes to the project code is to use Android Studio.
- Install the latest version of Android Studio 3 as specified [here](https://developer.android.com/studio/index.html).
- Make sure the Android SDK version is greater than 26 and NDK version is greater than 14 (in the Android Studio Settings).
- Import the tensorflow/contrib/lite/java/demo directory as a new Android Studio project.
- Import the `tensorflow/contrib/lite/java/demo` directory as a new Android Studio project.
- Click through installing all the Gradle extensions it requests.
- Download the quantized Mobilenet TensorFlow Lite model from [here](https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip)
- Unzip and copy mobilenet_quant_v1_224.tflite to the assets directory:
tensorflow/contrib/lite/java/demo/app/src/main/assets/
`tensorflow/contrib/lite/java/demo/app/src/main/assets/`
- Build and run the demo app
## Building TensorFlow Lite and the demo app from source
@ -43,39 +43,45 @@ The simplest way to compile the demo app, and try out changes to the project cod
### Install Bazel
If bazel is not installed on your system, install it now by following [these directions](https://bazel.build/versions/master/docs/install.html)
NOTE: Bazel does not currently support building for Android on Windows. Full support for gradle/cmake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead.
NOTE: Bazel does not fully support building Android on Windows yet. Full support for Gradle/CMake builds is coming soon, but in the meantime Windows users should download the [prebuilt binary](https://storage.googleapis.com/download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk) instead.
### Install Android NDK and SDK
Bazel is the primary build system for TensorFlow. Bazel and the Android NDK and SDK must be installed on your system.
- Install the latest version of Bazel as per the instructions on the [Bazel website](https://bazel.build/versions/master/docs/install.html)
- The Android NDK is required to build the native (C/C++) TensorFlow code. The current recommended version is 14b, which may be found [here](https://developer.android.com/tools/revisions/build-tools.html).
- The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TensorFlow Android demo (though it will run on API >= 21 devices).
- The Android NDK is required to build the native (C/C++) TensorFlow Lite code. The current recommended version is 14b, which can be found [here](https://developer.android.com/ndk/downloads/older_releases.html#ndk-14b-downloads).
- The Android SDK and build tools may be obtained [here](https://developer.android.com/tools/revisions/build-tools.html), or alternatively as part of [Android Studio](https://developer.android.com/studio/index.html). Build tools API >= 23 is required to build the TF Android demo (though it will run on API >= 21 devices).
- In the root of the TensorFlow repository update the `WORKSPACE` file with the `api_level` and location of the SDK and NDK. If you installed it with Android Studio, the SDK path can be found in the SDK manager, and the default NDK path is `{SDK path}/ndk-bundle`.
```
Android_sdk_repository (
name = "androidsdk",
api_level = 23,
build_tools_version = "23.0.2",
path = "/home/xxxx/android-sdk-linux/", )
android_sdk_repository (
name = "androidsdk",
api_level = 23,
build_tools_version = "23.0.2",
path = "/home/xxxx/android-sdk-linux/",
)
android_ndk_repository(
name="androidndk",
path="/home/xxxx/android-ndk-r10e/",
api_level=19)
name = "androidndk",
path = "/home/xxxx/android-ndk-r10e/",
api_level = 19,
)
```
Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md)
Additional details on building with Android can be found [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/README.md).
### Build the source code
Run bazel with the following command to build the demo.
Build the demo app:
bazel build --cxxopt='--std=c++11' //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo
```
bazel build --cxxopt=--std=c++11 //tensorflow/contrib/lite/java/demo/app/src/main:TfLiteCameraDemo
```
### Note
Currently, we only support building the Android demo app within a Python 2
environment (due to a Bazel bug).
### More about the demo
The demo resizes each camera image frame to (224 width * 224 height) to match the quantized Mobilenet model being used. The resized image is converted into a ByteBuffer row by row of size 1 * 224 * 224 * 3 bytes, where 1 is the number of images in the batch, 224 * 224 is the width and height of the image, and 3 bytes represents the three color channels of a pixel. This demo uses the TensorFlow Lite Java inference API for models which take a single input and provide a single output. This outputs a two-dimensional array, with the first dimension being the category index and the second dimension being the confidence of classification. The Mobilenet model has 1001 unique categories and the app sorts the probabilities of all the categories and displays the top three. The Mobilenet quantized model is bundled within the assets directory of the app.
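For readers following along in C++ rather than Java, a minimal sketch of the same flow, using the interpreter calls that appear in the iOS example added by this change, might look like this (it assumes an `interpreter` already built from the quantized MobileNet model with its tensors allocated, and a `frame` buffer holding the resized 224x224 RGB image; the function name is illustrative):
```
// Sketch only: feed one 224x224 RGB frame (uint8, NHWC) to a TensorFlow Lite
// interpreter and report the best-scoring category.
#include <algorithm>
#include <cstdint>
#include <cstdio>

#include "tensorflow/contrib/lite/interpreter.h"

// Assumes the interpreter was built via tflite::InterpreterBuilder and
// AllocateTensors() already succeeded; `frame` holds 224*224*3 bytes.
void ClassifyFrame(tflite::Interpreter* interpreter, const uint8_t* frame) {
  uint8_t* input = interpreter->typed_tensor<uint8_t>(interpreter->inputs()[0]);
  // Same layout as the Java ByteBuffer described above: 1 image x 224 rows x
  // 224 columns x 3 channels, one byte per channel.
  std::copy(frame, frame + 224 * 224 * 3, input);
  if (interpreter->Invoke() != kTfLiteOk) {
    std::fprintf(stderr, "Failed to invoke\n");
    return;
  }
  // The quantized MobileNet emits 1001 uint8 scores, one per category.
  const uint8_t* scores = interpreter->typed_output_tensor<uint8_t>(0);
  int best = 0;
  for (int i = 1; i < 1001; ++i) {
    if (scores[i] > scores[best]) best = i;
  }
  std::printf("top category %d, score %d/255\n", best, scores[best]);
}
```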
@ -95,7 +101,7 @@ The demo is resizing each camera image frame to (224 width * 224 height) to matc
[On Device Smart Reply](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html) is an on-device model which provides one-touch replies for an incoming text message by suggesting contextually relevant messages. The model is built specifically for memory constrained devices such as watches & phones and it has been successfully used to surface [Smart Replies on Android Wear](https://research.googleblog.com/2017/02/on-device-machine-intelligence.html). Note that this model only works on Android as of now.
These pre-trained models can be downloaded from [here](models.md).
These pre-trained models can be downloaded from [here](g3doc/models.md).
### Retrain Inception-V3 or MobileNet for a custom data set
The above pre-trained models have been trained on the ImageNet data set, which consists of 1000 predefined classes. A model will need to be re-trained if these classes are not relevant or useful for a given use case. This technique is called transfer learning, which starts with a model that has been already trained on a problem and will then be retrained on a similar problem. Deep learning from scratch can take days, but transfer learning can be done fairly quickly. In order to do this, a developer will need to generate their custom data set labeled with the relevant classes.
@ -104,7 +110,7 @@ The [TensorFlow for Poets](https://codelabs.developers.google.com/codelabs/tenso
### Train a custom model
A developer may choose to train a custom model using TensorFlow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow’s Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model.
A developer may choose to train a custom model using TensorFlow. TensorFlow documentation has [several tutorials](https://www.tensorflow.org/tutorials/) for building and training models. If the user has written a model using TensorFlow's Slim Framework the first step is to export this to a GraphDef file. This is necessary because Slim does not store the model structure outside the code, so to communicate with other parts of the framework it needs to be exported. Documentation for the export can be found [here](https://github.com/tensorflow/models/tree/master/research/slim#Export). The output of this step will be a .pb file for the custom model.
TensorFlow Lite currently supports a subset of TensorFlow operators. Please refer to [this document](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for details of supported operators and their usage. This
set will continue to expand in future releases of TensorFlow Lite.
@ -128,9 +134,9 @@ Since we employ several formats, the following definitions may be useful:
- TensorFlow Lite model (.lite) - a serialized flatbuffer, containing TensorFlow Lite operators and Tensors for the TensorFlow Lite interpreter. This is most analogous to TensorFlow frozen GraphDefs.
### Freeze Graph
To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as “freezing” the graph.
To use this .pb GraphDef file within TensorFlow Lite, the application developer will need checkpoints containing trained weight parameters. The .pb contains only the structure of the graph. The process of merging the checkpoint values with the graph structure is known as "freezing" the graph.
The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)
The developer should know where the checkpoints folder is present or checkpoints can also be downloaded for a pre-trained model (Example: Here is a link to the [MobileNets](https://github.com/tensorflow/models/blob/master/research/slim/nets/mobilenet_v1.md)).
Graph freezing can be done using the command below (and modifying the arguments appropriately)
@ -155,7 +161,7 @@ Here is a sample command line to convert the frozen Graphdef to '.lite' format f
bazel build tensorflow/contrib/lite/toco:toco
bazel-bin/tensorflow/contrib/lite/toco/toco -- \
--input_file=(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
--input_file=$(pwd)/mobilenet_v1_1.0_224/frozen_graph.pb \
--input_format=TENSORFLOW_GRAPHDEF --output_format=TFLITE \
--output_file=/tmp/mobilenet_v1_1.0_224.lite --inference_type=FLOAT \
--input_type=FLOAT --input_arrays=input \
@ -183,18 +189,18 @@ with tf.Session() as sess:
```
For detailed instructions on how to use the Tensorflow Optimizing Converter, please see [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/toco/g3doc/cmdline_examples.md).
You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/tf_ops_compatibility.md) for troubleshooting help. If that doesn’t help, please file an [issue](https://github.com/tensorflow/tensorflow/issues).
You may refer to the [Ops compatibility guide](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/tf_ops_compatibility.md) for troubleshooting help. If that doesn't help, please file an [issue](https://github.com/tensorflow/tensorflow/issues).
## Step 3. Use the TensorFlow Lite model for inference in a mobile app
After completion of Step 2 the developer should have a .lite model.
### For Android
Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/TensorFlow/TensorFlow/blob/master/TensorFlow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
Because Android apps need to be written in Java, and core TensorFlow is in C++, a JNI library is provided to interface between the two. Its interface is aimed only at inference, so it provides the ability to load a graph, set up inputs, and run the model to calculate particular outputs. The full documentation for the set of methods can be seen [here](https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/lite/g3doc/). The demo app is also open sourced on [github](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app).
The [demo app] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it’s a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
The [demo app](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/java/demo/app) uses this interface, so it's a good place to look for example usage. You can also download the prebuilt binary [here](http://download.tensorflow.org/deps/tflite/TfLiteCameraDemo.apk).
Note that you’d need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build).
Note that you'd need to follow instructions for installing TensorFlow on Android, setting up bazel and Android Studio outlined [here](https://www.tensorflow.org/mobile/android_build).
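Under the hood, loading a graph, setting up inputs, and running the model map onto a handful of C++ calls. A minimal sketch using the same API the iOS example in this change relies on (`RunOnce` and the model path argument are illustrative):
```
#include <memory>

#include "tensorflow/contrib/lite/interpreter.h"
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"

// Load a .tflite flatbuffer, wire up the builtin ops, build an interpreter,
// and run it once. The FlatBufferModel must outlive the interpreter, which
// reads directly from the mmapped buffer.
bool RunOnce(const char* model_path) {
  auto model = tflite::FlatBufferModel::BuildFromFile(model_path);
  if (!model) return false;
  tflite::ops::builtin::BuiltinOpResolver resolver;
  std::unique_ptr<tflite::Interpreter> interpreter;
  tflite::InterpreterBuilder(*model, resolver)(&interpreter);
  if (!interpreter || interpreter->AllocateTensors() != kTfLiteOk) return false;
  // ... fill interpreter->typed_tensor<uint8_t>(interpreter->inputs()[0]) ...
  return interpreter->Invoke() == kTfLiteOk;
}
```
The Java API wraps this same lifecycle: construct the interpreter once, then repeatedly fill inputs and run inference.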
### For iOS
Follow the documentation [here](https://github.com/TensorFlow/TensorFlow/blob/master/TensorFlow/contrib/lite/g3doc/ios.md) to integrate a TFLite model into your app.
Follow the documentation [here](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/g3doc/ios.md) to integrate a TFLite model into your app.

View File

@ -0,0 +1,31 @@
#!/bin/bash -x
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=x86_64 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=i386 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7 -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=armv7s -j 8
make -f tensorflow/contrib/lite/Makefile TARGET=IOS IOS_ARCH=arm64 -j 8
lipo \
tensorflow/contrib/lite/gen/lib/ios_x86_64/libtensorflow-lite.a \
tensorflow/contrib/lite/gen/lib/ios_i386/libtensorflow-lite.a \
tensorflow/contrib/lite/gen/lib/ios_armv7/libtensorflow-lite.a \
tensorflow/contrib/lite/gen/lib/ios_armv7s/libtensorflow-lite.a \
tensorflow/contrib/lite/gen/lib/ios_arm64/libtensorflow-lite.a \
-create \
-output tensorflow/contrib/lite/gen/lib/libtensorflow-lite.a

View File

@ -0,0 +1,99 @@
#!/bin/bash
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
set -e
DOWNLOADS_DIR=tensorflow/contrib/lite/downloads
BZL_FILE_PATH=tensorflow/workspace.bzl
EIGEN_URL="$(grep -o 'http.*bitbucket.org/eigen/eigen/get/.*tar\.gz' "${BZL_FILE_PATH}" | grep -v bazel-mirror | head -n1)"
GEMMLOWP_URL="$(grep -o 'https://mirror.bazel.build/github.com/google/gemmlowp/.*zip' "${BZL_FILE_PATH}" | head -n1)"
GOOGLETEST_URL="https://github.com/google/googletest/archive/release-1.8.0.tar.gz"
ABSL_URL="$(grep -o 'https://github.com/abseil/abseil-cpp/.*tar.gz' "${BZL_FILE_PATH}" | head -n1)"
NEON_2_SSE_URL="https://github.com/intel/ARM_NEON_2_x86_SSE/archive/master.zip"
FARMHASH_URL="https://mirror.bazel.build/github.com/google/farmhash/archive/816a4ae622e964763ca0862d9dbd19324a1eaf45.tar.gz"
FLATBUFFERS_URL="https://github.com/google/flatbuffers/archive/master.zip"
MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_1.0_224_ios_lite_float_2017_11_08.zip"
QUANTIZED_MODELS_URL="https://storage.googleapis.com/download.tensorflow.org/models/tflite/mobilenet_v1_224_android_quant_2017_11_08.zip"
# TODO(petewarden): Some new code in Eigen triggers a clang bug with iOS arm64,
# so work around it by patching the source.
replace_by_sed() {
local regex="${1}"
shift
# Detect the version of sed by the return value of "--version" flag. GNU-sed
# supports "--version" while BSD-sed doesn't.
if ! sed --version >/dev/null 2>&1; then
# BSD-sed.
sed -i '' -e "${regex}" "$@"
else
# GNU-sed.
sed -i -e "${regex}" "$@"
fi
}
download_and_extract() {
local usage="Usage: download_and_extract URL DIR"
local url="${1:?${usage}}"
local dir="${2:?${usage}}"
echo "downloading ${url}" >&2
mkdir -p "${dir}"
if [[ "${url}" == *gz ]]; then
curl -Ls "${url}" | tar -C "${dir}" --strip-components=1 -xz
elif [[ "${url}" == *zip ]]; then
tempdir=$(mktemp -d)
tempdir2=$(mktemp -d)
curl -L ${url} > ${tempdir}/zipped.zip
unzip ${tempdir}/zipped.zip -d ${tempdir2}
# If the zip file contains nested directories, extract the files from the
# inner directory.
if ls ${tempdir2}/*/* 1> /dev/null 2>&1; then
# unzip has no strip components, so unzip to a temp dir, and move the
# files we want from the tempdir to destination.
cp -R ${tempdir2}/*/* ${dir}/
else
cp -R ${tempdir2}/* ${dir}/
fi
rm -rf ${tempdir2} ${tempdir}
fi
# Delete any potential BUILD files, which would interfere with Bazel builds.
find "${dir}" -type f -name '*BUILD' -delete
}
download_and_extract "${EIGEN_URL}" "${DOWNLOADS_DIR}/eigen"
download_and_extract "${GEMMLOWP_URL}" "${DOWNLOADS_DIR}/gemmlowp"
download_and_extract "${GOOGLETEST_URL}" "${DOWNLOADS_DIR}/googletest"
download_and_extract "${ABSL_URL}" "${DOWNLOADS_DIR}/absl"
download_and_extract "${NEON_2_SSE_URL}" "${DOWNLOADS_DIR}/neon_2_sse"
download_and_extract "${FARMHASH_URL}" "${DOWNLOADS_DIR}/farmhash"
download_and_extract "${FLATBUFFERS_URL}" "${DOWNLOADS_DIR}/flatbuffers"
download_and_extract "${MODELS_URL}" "${DOWNLOADS_DIR}/models"
download_and_extract "${QUANTIZED_MODELS_URL}" "${DOWNLOADS_DIR}/quantized_models"
replace_by_sed 's#static uint32x4_t p4ui_CONJ_XOR = vld1q_u32( conj_XOR_DATA );#static uint32x4_t p4ui_CONJ_XOR; // = vld1q_u32( conj_XOR_DATA ); - Removed by script#' \
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
replace_by_sed 's#static uint32x2_t p2ui_CONJ_XOR = vld1_u32( conj_XOR_DATA );#static uint32x2_t p2ui_CONJ_XOR;// = vld1_u32( conj_XOR_DATA ); - Removed by scripts#' \
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
replace_by_sed 's#static uint64x2_t p2ul_CONJ_XOR = vld1q_u64( p2ul_conj_XOR_DATA );#static uint64x2_t p2ul_CONJ_XOR;// = vld1q_u64( p2ul_conj_XOR_DATA ); - Removed by script#' \
"${DOWNLOADS_DIR}/eigen/Eigen/src/Core/arch/NEON/Complex.h"
cp ${DOWNLOADS_DIR}/models/models/* tensorflow/contrib/lite/examples/ios/simple/data/
cp ${DOWNLOADS_DIR}/quantized_models/* tensorflow/contrib/lite/examples/ios/camera/data/
echo "download_dependencies.sh completed successfully." >&2

View File

@ -0,0 +1,2 @@
/data/*.txt
/data/*.tflite

View File

@ -0,0 +1,21 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <UIKit/UIKit.h>
@interface CameraExampleAppDelegate : UIResponder<UIApplicationDelegate>
@property(strong, nonatomic) UIWindow* window;
@end

View File

@ -0,0 +1,44 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "CameraExampleAppDelegate.h"
@implementation CameraExampleAppDelegate
@synthesize window = _window;
- (BOOL)application:(UIApplication *)application
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
[self.window makeKeyAndVisible];
return YES;
}
- (void)applicationWillResignActive:(UIApplication *)application {
[[UIApplication sharedApplication] setIdleTimerDisabled:NO];
}
- (void)applicationDidEnterBackground:(UIApplication *)application {
}
- (void)applicationWillEnterForeground:(UIApplication *)application {
}
- (void)applicationDidBecomeActive:(UIApplication *)application {
[[UIApplication sharedApplication] setIdleTimerDisabled:YES];
}
- (void)applicationWillTerminate:(UIApplication *)application {
}
@end

View File

@ -0,0 +1,48 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <AVFoundation/AVFoundation.h>
#import <UIKit/UIKit.h>
#include <vector>
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
@interface CameraExampleViewController
: UIViewController<UIGestureRecognizerDelegate, AVCaptureVideoDataOutputSampleBufferDelegate> {
IBOutlet UIView* previewView;
AVCaptureVideoPreviewLayer* previewLayer;
AVCaptureVideoDataOutput* videoDataOutput;
dispatch_queue_t videoDataOutputQueue;
UIView* flashView;
BOOL isUsingFrontFacingCamera;
NSMutableDictionary* oldPredictionValues;
NSMutableArray* labelLayers;
AVCaptureSession* session;
std::vector<std::string> labels;
std::unique_ptr<tflite::FlatBufferModel> model;
tflite::ops::builtin::BuiltinOpResolver resolver;
std::unique_ptr<tflite::Interpreter> interpreter;
double total_latency;
int total_count;
}
@property(strong, nonatomic) CATextLayer* predictionTextLayer;
- (IBAction)takePicture:(id)sender;
- (IBAction)switchCameras:(id)sender;
@end

View File

@ -0,0 +1,506 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "CameraExampleViewController.h"
#import <AssertMacros.h>
#import <AssetsLibrary/AssetsLibrary.h>
#import <CoreImage/CoreImage.h>
#import <ImageIO/ImageIO.h>
#include <sys/time.h>
#include <fstream>
#include <iostream>
#include <queue>
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
#define LOG(x) std::cerr
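// Note: this stand-in LOG macro ignores its severity argument, so LOG(FATAL)
// only streams to std::cerr and does not abort the app.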
// If you have your own model, modify this to the file name, and make sure
// you've added the file to your app resources too.
static NSString* model_file_name = @"mobilenet_quant_v1_224";
static NSString* model_file_type = @"tflite";
// If you have your own model, point this to the labels file.
static NSString* labels_file_name = @"labels";
static NSString* labels_file_type = @"txt";
// These dimensions need to match those the model was trained with.
static const int wanted_input_width = 224;
static const int wanted_input_height = 224;
static const int wanted_input_channels = 3;
static NSString* FilePathForResourceName(NSString* name, NSString* extension) {
NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
if (file_path == NULL) {
LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String]
<< "' in bundle.";
}
return file_path;
}
static void LoadLabels(NSString* file_name, NSString* file_type,
std::vector<std::string>* label_strings) {
NSString* labels_path = FilePathForResourceName(file_name, file_type);
if (!labels_path) {
LOG(ERROR) << "Failed to find model proto at " << [file_name UTF8String]
<< [file_type UTF8String];
}
std::ifstream t;
t.open([labels_path UTF8String]);
std::string line;
while (std::getline(t, line)) {
label_strings->push_back(line);
}
t.close();
}
// Returns the top N confidence values over threshold in the provided vector,
// sorted by confidence in descending order.
static void GetTopN(const uint8_t* prediction, const int prediction_size, const int num_results,
const float threshold, std::vector<std::pair<float, int>>* top_results) {
// Will contain top N results in ascending order.
std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int>>,
std::greater<std::pair<float, int>>>
top_result_pq;
const long count = prediction_size;
for (int i = 0; i < count; ++i) {
const float value = prediction[i] / 255.0;
// Only add it if it beats the threshold and has a chance at being in
// the top N.
if (value < threshold) {
continue;
}
top_result_pq.push(std::pair<float, int>(value, i));
// If at capacity, kick the smallest value out.
if (top_result_pq.size() > num_results) {
top_result_pq.pop();
}
}
// Copy to output vector and reverse into descending order.
while (!top_result_pq.empty()) {
top_results->push_back(top_result_pq.top());
top_result_pq.pop();
}
std::reverse(top_results->begin(), top_results->end());
}
@interface CameraExampleViewController (InternalMethods)
- (void)setupAVCapture;
- (void)teardownAVCapture;
@end
@implementation CameraExampleViewController
- (void)setupAVCapture {
NSError* error = nil;
session = [AVCaptureSession new];
if ([[UIDevice currentDevice] userInterfaceIdiom] == UIUserInterfaceIdiomPhone)
[session setSessionPreset:AVCaptureSessionPreset640x480];
else
[session setSessionPreset:AVCaptureSessionPresetPhoto];
AVCaptureDevice* device = [AVCaptureDevice defaultDeviceWithMediaType:AVMediaTypeVideo];
AVCaptureDeviceInput* deviceInput =
[AVCaptureDeviceInput deviceInputWithDevice:device error:&error];
assert(error == nil);
if ([session canAddInput:deviceInput]) [session addInput:deviceInput];
videoDataOutput = [AVCaptureVideoDataOutput new];
NSDictionary* rgbOutputSettings =
[NSDictionary dictionaryWithObject:[NSNumber numberWithInt:kCMPixelFormat_32BGRA]
forKey:(id)kCVPixelBufferPixelFormatTypeKey];
[videoDataOutput setVideoSettings:rgbOutputSettings];
[videoDataOutput setAlwaysDiscardsLateVideoFrames:YES];
videoDataOutputQueue = dispatch_queue_create("VideoDataOutputQueue", DISPATCH_QUEUE_SERIAL);
[videoDataOutput setSampleBufferDelegate:self queue:videoDataOutputQueue];
if ([session canAddOutput:videoDataOutput]) [session addOutput:videoDataOutput];
[[videoDataOutput connectionWithMediaType:AVMediaTypeVideo] setEnabled:YES];
previewLayer = [[AVCaptureVideoPreviewLayer alloc] initWithSession:session];
[previewLayer setBackgroundColor:[[UIColor blackColor] CGColor]];
[previewLayer setVideoGravity:AVLayerVideoGravityResizeAspect];
CALayer* rootLayer = [previewView layer];
[rootLayer setMasksToBounds:YES];
[previewLayer setFrame:[rootLayer bounds]];
[rootLayer addSublayer:previewLayer];
[session startRunning];
if (error) {
NSString* title = [NSString stringWithFormat:@"Failed with error %d", (int)[error code]];
UIAlertController* alertController =
[UIAlertController alertControllerWithTitle:title
message:[error localizedDescription]
preferredStyle:UIAlertControllerStyleAlert];
UIAlertAction* dismiss =
[UIAlertAction actionWithTitle:@"Dismiss" style:UIAlertActionStyleDefault handler:nil];
[alertController addAction:dismiss];
[self presentViewController:alertController animated:YES completion:nil];
[self teardownAVCapture];
}
}
- (void)teardownAVCapture {
[previewLayer removeFromSuperlayer];
}
- (AVCaptureVideoOrientation)avOrientationForDeviceOrientation:
(UIDeviceOrientation)deviceOrientation {
AVCaptureVideoOrientation result = (AVCaptureVideoOrientation)(deviceOrientation);
if (deviceOrientation == UIDeviceOrientationLandscapeLeft)
result = AVCaptureVideoOrientationLandscapeRight;
else if (deviceOrientation == UIDeviceOrientationLandscapeRight)
result = AVCaptureVideoOrientationLandscapeLeft;
return result;
}
- (IBAction)takePicture:(id)sender {
if ([session isRunning]) {
[session stopRunning];
[sender setTitle:@"Continue" forState:UIControlStateNormal];
flashView = [[UIView alloc] initWithFrame:[previewView frame]];
[flashView setBackgroundColor:[UIColor whiteColor]];
[flashView setAlpha:0.f];
[[[self view] window] addSubview:flashView];
[UIView animateWithDuration:.2f
animations:^{
[flashView setAlpha:1.f];
}
completion:^(BOOL finished) {
[UIView animateWithDuration:.2f
animations:^{
[flashView setAlpha:0.f];
}
completion:^(BOOL finished) {
[flashView removeFromSuperview];
flashView = nil;
}];
}];
} else {
[session startRunning];
[sender setTitle:@"Freeze Frame" forState:UIControlStateNormal];
}
}
- (void)captureOutput:(AVCaptureOutput*)captureOutput
didOutputSampleBuffer:(CMSampleBufferRef)sampleBuffer
fromConnection:(AVCaptureConnection*)connection {
CVPixelBufferRef pixelBuffer = CMSampleBufferGetImageBuffer(sampleBuffer);
CFRetain(pixelBuffer);
[self runModelOnFrame:pixelBuffer];
CFRelease(pixelBuffer);
}
- (void)runModelOnFrame:(CVPixelBufferRef)pixelBuffer {
assert(pixelBuffer != NULL);
OSType sourcePixelFormat = CVPixelBufferGetPixelFormatType(pixelBuffer);
int doReverseChannels;
if (kCVPixelFormatType_32ARGB == sourcePixelFormat) {
doReverseChannels = 1;
} else if (kCVPixelFormatType_32BGRA == sourcePixelFormat) {
doReverseChannels = 0;
} else {
assert(false); // Unknown source format
}
const int sourceRowBytes = (int)CVPixelBufferGetBytesPerRow(pixelBuffer);
const int image_width = (int)CVPixelBufferGetWidth(pixelBuffer);
const int fullHeight = (int)CVPixelBufferGetHeight(pixelBuffer);
CVPixelBufferLockFlags unlockFlags = kNilOptions;
CVPixelBufferLockBaseAddress(pixelBuffer, unlockFlags);
unsigned char* sourceBaseAddr = (unsigned char*)(CVPixelBufferGetBaseAddress(pixelBuffer));
int image_height;
unsigned char* sourceStartAddr;
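// If the frame is taller than it is wide, crop to a vertically centered
// square; otherwise the full frame is used as-is.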
if (fullHeight <= image_width) {
image_height = fullHeight;
sourceStartAddr = sourceBaseAddr;
} else {
image_height = image_width;
const int marginY = ((fullHeight - image_width) / 2);
sourceStartAddr = (sourceBaseAddr + (marginY * sourceRowBytes));
}
const int image_channels = 4;
assert(image_channels >= wanted_input_channels);
uint8_t* in = sourceStartAddr;
int input = interpreter->inputs()[0];
uint8_t* out = interpreter->typed_tensor<uint8_t>(input);
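// Nearest-neighbour resample of the cropped region into the 224x224x3 input
// tensor. Note the swapped x/y when indexing the source: the image is read
// transposed, which appears to compensate for the sensor orientation of
// portrait-mode camera frames.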
for (int y = 0; y < wanted_input_height; ++y) {
uint8_t* out_row = out + (y * wanted_input_width * wanted_input_channels);
for (int x = 0; x < wanted_input_width; ++x) {
const int in_x = (y * image_width) / wanted_input_width;
const int in_y = (x * image_height) / wanted_input_height;
uint8_t* in_pixel = in + (in_y * image_width * image_channels) + (in_x * image_channels);
uint8_t* out_pixel = out_row + (x * wanted_input_channels);
for (int c = 0; c < wanted_input_channels; ++c) {
out_pixel[c] = in_pixel[c];
}
}
}
double startTimestamp = [[NSDate new] timeIntervalSince1970];
if (interpreter->Invoke() != kTfLiteOk) {
LOG(FATAL) << "Failed to invoke!";
}
double endTimestamp = [[NSDate new] timeIntervalSince1970];
total_latency += (endTimestamp - startTimestamp);
total_count += 1;
NSLog(@"Time: %.4lf, avg: %.4lf, count: %d", endTimestamp - startTimestamp,
total_latency / total_count, total_count);
const int output_size = 1000;
const int kNumResults = 5;
const float kThreshold = 0.1f;
std::vector<std::pair<float, int>> top_results;
uint8_t* output = interpreter->typed_output_tensor<uint8_t>(0);
GetTopN(output, output_size, kNumResults, kThreshold, &top_results);
NSMutableDictionary* newValues = [NSMutableDictionary dictionary];
for (const auto& result : top_results) {
const float confidence = result.first;
const int index = result.second;
NSString* labelObject = [NSString stringWithUTF8String:labels[index].c_str()];
NSNumber* valueObject = [NSNumber numberWithFloat:confidence];
[newValues setObject:valueObject forKey:labelObject];
}
dispatch_async(dispatch_get_main_queue(), ^(void) {
[self setPredictionValues:newValues];
});
CVPixelBufferUnlockBaseAddress(pixelBuffer, unlockFlags);
}
- (void)dealloc {
[self teardownAVCapture];
}
- (void)didReceiveMemoryWarning {
[super didReceiveMemoryWarning];
}
- (void)viewDidLoad {
[super viewDidLoad];
labelLayers = [[NSMutableArray alloc] init];
oldPredictionValues = [[NSMutableDictionary alloc] init];
NSString* graph_path = FilePathForResourceName(model_file_name, @"tflite");
model = tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]);
if (!model) {
LOG(FATAL) << "Failed to mmap model " << graph_path;
}
LOG(INFO) << "Loaded model " << graph_path;
model->error_reporter();
LOG(INFO) << "resolved reporter";
tflite::ops::builtin::BuiltinOpResolver resolver;
LoadLabels(labels_file_name, labels_file_type, &labels);
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
LOG(FATAL) << "Failed to construct interpreter";
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(FATAL) << "Failed to allocate tensors!";
}
[self setupAVCapture];
}
- (void)viewDidUnload {
[super viewDidUnload];
}
- (void)viewWillAppear:(BOOL)animated {
[super viewWillAppear:animated];
}
- (void)viewDidAppear:(BOOL)animated {
[super viewDidAppear:animated];
}
- (void)viewWillDisappear:(BOOL)animated {
[super viewWillDisappear:animated];
}
- (void)viewDidDisappear:(BOOL)animated {
[super viewDidDisappear:animated];
}
- (BOOL)shouldAutorotateToInterfaceOrientation:(UIInterfaceOrientation)interfaceOrientation {
return (interfaceOrientation == UIInterfaceOrientationPortrait);
}
- (BOOL)prefersStatusBarHidden {
return YES;
}
- (void)setPredictionValues:(NSDictionary*)newValues {
const float decayValue = 0.75f;
const float updateValue = 0.25f;
const float minimumThreshold = 0.01f;
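// Exponential smoothing across frames: previous values decay by a factor of
// 0.75, new detections blend in at 0.25, and labels whose decayed value falls
// to 0.01 or below are dropped from the running dictionary.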
NSMutableDictionary* decayedPredictionValues = [[NSMutableDictionary alloc] init];
for (NSString* label in oldPredictionValues) {
NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label];
const float oldPredictionValue = [oldPredictionValueObject floatValue];
const float decayedPredictionValue = (oldPredictionValue * decayValue);
if (decayedPredictionValue > minimumThreshold) {
NSNumber* decayedPredictionValueObject = [NSNumber numberWithFloat:decayedPredictionValue];
[decayedPredictionValues setObject:decayedPredictionValueObject forKey:label];
}
}
oldPredictionValues = decayedPredictionValues;
for (NSString* label in newValues) {
NSNumber* newPredictionValueObject = [newValues objectForKey:label];
NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label];
if (!oldPredictionValueObject) {
oldPredictionValueObject = [NSNumber numberWithFloat:0.0f];
}
const float newPredictionValue = [newPredictionValueObject floatValue];
const float oldPredictionValue = [oldPredictionValueObject floatValue];
const float updatedPredictionValue = (oldPredictionValue + (newPredictionValue * updateValue));
NSNumber* updatedPredictionValueObject = [NSNumber numberWithFloat:updatedPredictionValue];
[oldPredictionValues setObject:updatedPredictionValueObject forKey:label];
}
NSArray* candidateLabels = [NSMutableArray array];
for (NSString* label in oldPredictionValues) {
NSNumber* oldPredictionValueObject = [oldPredictionValues objectForKey:label];
const float oldPredictionValue = [oldPredictionValueObject floatValue];
if (oldPredictionValue > 0.05f) {
NSDictionary* entry = @{@"label" : label, @"value" : oldPredictionValueObject};
candidateLabels = [candidateLabels arrayByAddingObject:entry];
}
}
NSSortDescriptor* sort = [NSSortDescriptor sortDescriptorWithKey:@"value" ascending:NO];
NSArray* sortedLabels =
[candidateLabels sortedArrayUsingDescriptors:[NSArray arrayWithObject:sort]];
const float leftMargin = 10.0f;
const float topMargin = 10.0f;
const float valueWidth = 48.0f;
const float valueHeight = 18.0f;
const float labelWidth = 246.0f;
const float labelHeight = 18.0f;
const float labelMarginX = 5.0f;
const float labelMarginY = 5.0f;
[self removeAllLabelLayers];
int labelCount = 0;
for (NSDictionary* entry in sortedLabels) {
NSString* label = [entry objectForKey:@"label"];
NSNumber* valueObject = [entry objectForKey:@"value"];
const float value = [valueObject floatValue];
const float originY = topMargin + ((labelHeight + labelMarginY) * labelCount);
const int valuePercentage = (int)roundf(value * 100.0f);
const float valueOriginX = leftMargin;
NSString* valueText = [NSString stringWithFormat:@"%d%%", valuePercentage];
[self addLabelLayerWithText:valueText
originX:valueOriginX
originY:originY
width:valueWidth
height:valueHeight
alignment:kCAAlignmentRight];
const float labelOriginX = (leftMargin + valueWidth + labelMarginX);
[self addLabelLayerWithText:[label capitalizedString]
originX:labelOriginX
originY:originY
width:labelWidth
height:labelHeight
alignment:kCAAlignmentLeft];
labelCount += 1;
if (labelCount > 4) {
break;
}
}
}
- (void)removeAllLabelLayers {
for (CATextLayer* layer in labelLayers) {
[layer removeFromSuperlayer];
}
[labelLayers removeAllObjects];
}
- (void)addLabelLayerWithText:(NSString*)text
originX:(float)originX
originY:(float)originY
width:(float)width
height:(float)height
alignment:(NSString*)alignment {
CFTypeRef font = (CFTypeRef) @"Menlo-Regular";
const float fontSize = 12.0;
const float marginSizeX = 5.0f;
const float marginSizeY = 2.0f;
const CGRect backgroundBounds = CGRectMake(originX, originY, width, height);
const CGRect textBounds = CGRectMake((originX + marginSizeX), (originY + marginSizeY),
(width - (marginSizeX * 2)), (height - (marginSizeY * 2)));
CATextLayer* background = [CATextLayer layer];
[background setBackgroundColor:[UIColor blackColor].CGColor];
[background setOpacity:0.5f];
[background setFrame:backgroundBounds];
background.cornerRadius = 5.0f;
[[self.view layer] addSublayer:background];
[labelLayers addObject:background];
CATextLayer* layer = [CATextLayer layer];
[layer setForegroundColor:[UIColor whiteColor].CGColor];
[layer setFrame:textBounds];
[layer setAlignmentMode:alignment];
[layer setWrapped:YES];
[layer setFont:font];
[layer setFontSize:fontSize];
layer.contentsScale = [[UIScreen mainScreen] scale];
[layer setString:text];
[[self.view layer] addSublayer:layer];
[labelLayers addObject:layer];
}
@end

View File

@ -0,0 +1,44 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleDisplayName</key>
<string>tflite_camera_example</string>
<key>CFBundleExecutable</key>
<string>${EXECUTABLE_NAME}</string>
<key>CFBundleIdentifier</key>
<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
<string>${PRODUCT_NAME}</string>
<key>CFBundlePackageType</key>
<string>APPL</string>
<key>CFBundleShortVersionString</key>
<string>1.0</string>
<key>CFBundleSignature</key>
<string>????</string>
<key>CFBundleVersion</key>
<string>1.0</string>
<key>LSRequiresIPhoneOS</key>
<true/>
<key>NSCameraUsageDescription</key>
<string>Capture images to detect object</string>
<key>UIMainStoryboardFile</key>
<string>MainStoryboard_iPhone</string>
<key>UIRequiresFullScreen</key>
<true/>
<key>UIStatusBarHidden</key>
<true/>
<key>UISupportedInterfaceOrientations</key>
<array>
<string>UIInterfaceOrientationPortrait</string>
</array>
<key>UISupportedInterfaceOrientations~ipad</key>
<array>
<string>UIInterfaceOrientationPortrait</string>
</array>
</dict>
</plist>

View File

@ -0,0 +1,46 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<document type="com.apple.InterfaceBuilder3.CocoaTouch.Storyboard.XIB" version="3.0" toolsVersion="9531" systemVersion="15E65" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" initialViewController="2">
<dependencies>
<deployment identifier="iOS"/>
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="9529"/>
</dependencies>
<scenes>
<!--Camera Example View Controller-->
<scene sceneID="5">
<objects>
<viewController id="2" customClass="CameraExampleViewController" sceneMemberID="viewController">
<view key="view" contentMode="scaleToFill" id="3">
<rect key="frame" x="0.0" y="0.0" width="320" height="568"/>
<autoresizingMask key="autoresizingMask" flexibleMaxX="YES" flexibleMaxY="YES"/>
<subviews>
<view contentMode="scaleToFill" id="12">
<rect key="frame" x="0.0" y="0.0" width="320" height="522"/>
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
<color key="backgroundColor" white="1" alpha="1" colorSpace="custom" customColorSpace="calibratedWhite"/>
<gestureRecognizers/>
</view>
<button opaque="NO" contentMode="scaleToFill" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" lineBreakMode="middleTruncation" id="iD8-yH-eWH">
<rect key="frame" x="0.0" y="454" width="320" height="33"/>
<autoresizingMask key="autoresizingMask" flexibleMaxX="YES" flexibleMaxY="YES"/>
<color key="backgroundColor" red="0.0" green="0.0" blue="0.0" alpha="1" colorSpace="calibratedRGB"/>
<fontDescription key="fontDescription" name="Menlo-Regular" family="Menlo" pointSize="20"/>
<state key="normal" title="Freeze Frame">
<color key="titleColor" white="1" alpha="1" colorSpace="calibratedWhite"/>
<color key="titleShadowColor" white="0.5" alpha="1" colorSpace="calibratedWhite"/>
</state>
<connections>
<action selector="takePicture:" destination="2" eventType="touchUpInside" id="BTy-7E-XUS"/>
</connections>
</button>
</subviews>
<color key="backgroundColor" red="0.0" green="0.0" blue="0.0" alpha="1" colorSpace="calibratedRGB"/>
</view>
<connections>
<outlet property="previewView" destination="12" id="13"/>
</connections>
</viewController>
<placeholder placeholderIdentifier="IBFirstResponder" id="4" sceneMemberID="firstResponder"/>
</objects>
</scene>
</scenes>
</document>

View File

@ -0,0 +1,5 @@
platform :ios, '8.0'
inhibit_all_warnings!
target 'tflite_camera_example'
pod 'TensorFlow-experimental'

View File

@ -0,0 +1,28 @@
// Copyright 2017 The TensorFlow Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <UIKit/UIKit.h>
#import "CameraExampleAppDelegate.h"
int main(int argc, char* argv[]) {
int retVal = 0;
@autoreleasepool {
retVal =
UIApplicationMain(argc, argv, nil, NSStringFromClass([CameraExampleAppDelegate class]));
}
return retVal;
}

View File

@ -0,0 +1,419 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 46;
objects = {
/* Begin PBXBuildFile section */
1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1C3C9DCA1ED3AB4200B8B5FA /* main.mm */; };
1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */ = {isa = PBXBuildFile; fileRef = 1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */; };
1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */; };
1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */; };
1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */; };
1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */; };
1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */ = {isa = PBXBuildFile; fileRef = 1CDB2D4D1ED3AA35007929E9 /* Info.plist */; };
54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */; };
AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = AC1F82641FBA3CBD0052BA77 /* labels.txt */; };
AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */; };
ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; };
1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
1C3C9DCA1ED3AB4200B8B5FA /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = "<group>"; };
1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tflite_camera_example.app; sourceTree = BUILT_PRODUCTS_DIR; };
1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.storyboard; path = MainStoryboard_iPhone.storyboard; sourceTree = "<group>"; };
1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; };
1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreMedia.framework; path = System/Library/Frameworks/CoreMedia.framework; sourceTree = SDKROOT; };
1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = AVFoundation.framework; path = System/Library/Frameworks/AVFoundation.framework; sourceTree = SDKROOT; };
1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleAppDelegate.h; sourceTree = "<group>"; };
1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.objc; path = CameraExampleAppDelegate.m; sourceTree = "<group>"; };
1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = CameraExampleViewController.h; sourceTree = "<group>"; };
1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = CameraExampleViewController.mm; sourceTree = "<group>"; };
1CDB2D4D1ED3AA35007929E9 /* Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = Info.plist; sourceTree = "<group>"; };
3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tflite_camera_example.a"; sourceTree = BUILT_PRODUCTS_DIR; };
3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.debug.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.debug.xcconfig"; sourceTree = "<group>"; };
55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */ = {isa = PBXFileReference; includeInIndex = 1; lastKnownFileType = text.xcconfig; name = "Pods-tflite_camera_example.release.xcconfig"; path = "Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example.release.xcconfig"; sourceTree = "<group>"; };
AC1F82641FBA3CBD0052BA77 /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = "<group>"; };
AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = "<group>"; };
ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_quant_v1_224.tflite; sourceTree = "<group>"; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
1C564C0A1ED3A92E00087306 /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
AC1F82691FBA3F930052BA77 /* libtensorflow-lite.a in Frameworks */,
1CB47D491ED3AD1700DF7666 /* AVFoundation.framework in Frameworks */,
1CA5EB931ED3ABFB00247A34 /* CoreMedia.framework in Frameworks */,
54DC6C3C5F734F3A58069F0C /* libPods-tflite_camera_example.a in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
24D7686C331131624F4454A0 /* Frameworks */ = {
isa = PBXGroup;
children = (
AC1F82681FBA3F930052BA77 /* libtensorflow-lite.a */,
1CB47D481ED3AD1700DF7666 /* AVFoundation.framework */,
1CA5EB921ED3ABFB00247A34 /* CoreMedia.framework */,
1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */,
1C0D73481ECCC41B008C1DAB /* CoreImage.framework */,
1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */,
3BA8BF92C84895BFE59D8236 /* libPods-tflite_camera_example.a */,
);
name = Frameworks;
sourceTree = "<group>";
};
3E9FC355632FB928EA23BEED /* Pods */ = {
isa = PBXGroup;
children = (
3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */,
55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */,
);
name = Pods;
sourceTree = "<group>";
};
591157921CF4011C00C31E3A = {
isa = PBXGroup;
children = (
1C99111B1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard */,
1C3C9DCA1ED3AB4200B8B5FA /* main.mm */,
1CDB2D4D1ED3AA35007929E9 /* Info.plist */,
1CDB2D421ED3A9CD007929E9 /* CameraExampleAppDelegate.h */,
1CDB2D431ED3A9CD007929E9 /* CameraExampleAppDelegate.m */,
1CDB2D441ED3A9CD007929E9 /* CameraExampleViewController.h */,
1CDB2D451ED3A9CD007929E9 /* CameraExampleViewController.mm */,
59A3CFF31CF4E68100C4259F /* data */,
5911579C1CF4011C00C31E3A /* Products */,
3E9FC355632FB928EA23BEED /* Pods */,
24D7686C331131624F4454A0 /* Frameworks */,
);
sourceTree = "<group>";
};
5911579C1CF4011C00C31E3A /* Products */ = {
isa = PBXGroup;
children = (
1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */,
);
name = Products;
sourceTree = "<group>";
};
59A3CFF31CF4E68100C4259F /* data */ = {
isa = PBXGroup;
children = (
ACA1A4C91FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite */,
AC1F82641FBA3CBD0052BA77 /* labels.txt */,
);
path = data;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
1C564C0C1ED3A92E00087306 /* tflite_camera_example */ = {
isa = PBXNativeTarget;
buildConfigurationList = 1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */;
buildPhases = (
66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */,
1C564C091ED3A92E00087306 /* Sources */,
1C564C0A1ED3A92E00087306 /* Frameworks */,
1C564C0B1ED3A92E00087306 /* Resources */,
00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */,
5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */,
);
buildRules = (
);
dependencies = (
);
name = tflite_camera_example;
productName = tflite_camera_example;
productReference = 1C564C0D1ED3A92E00087306 /* tflite_camera_example.app */;
productType = "com.apple.product-type.application";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
591157931CF4011C00C31E3A /* Project object */ = {
isa = PBXProject;
attributes = {
LastSwiftUpdateCheck = 0830;
LastUpgradeCheck = 0830;
ORGANIZATIONNAME = Google;
TargetAttributes = {
1C564C0C1ED3A92E00087306 = {
CreatedOnToolsVersion = 8.3.2;
DevelopmentTeam = EQHXZ8M8AV;
ProvisioningStyle = Automatic;
};
};
};
buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tflite_camera_example" */;
compatibilityVersion = "Xcode 3.2";
developmentRegion = English;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 591157921CF4011C00C31E3A;
productRefGroup = 5911579C1CF4011C00C31E3A /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
1C564C0C1ED3A92E00087306 /* tflite_camera_example */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
1C564C0B1ED3A92E00087306 /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
ACA1A4CA1FBB6C28009B8D86 /* mobilenet_quant_v1_224.tflite in Resources */,
1C99111C1ED3B0E600A6BFB9 /* MainStoryboard_iPhone.storyboard in Resources */,
1CDB2D4E1ED3AA35007929E9 /* Info.plist in Resources */,
AC1F82661FBA3CBD0052BA77 /* labels.txt in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXShellScriptBuildPhase section */
00E875C3B066535AE6B77101 /* [CP] Embed Pods Frameworks */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputPaths = (
);
name = "[CP] Embed Pods Frameworks";
outputPaths = (
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-frameworks.sh\"\n";
showEnvVarsInLog = 0;
};
5C2D02120E3E5E09567AA946 /* [CP] Copy Pods Resources */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputPaths = (
);
name = "[CP] Copy Pods Resources";
outputPaths = (
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "\"${SRCROOT}/Pods/Target Support Files/Pods-tflite_camera_example/Pods-tflite_camera_example-resources.sh\"\n";
showEnvVarsInLog = 0;
};
66DAEAAEE9EF6550C3A061E0 /* [CP] Check Pods Manifest.lock */ = {
isa = PBXShellScriptBuildPhase;
buildActionMask = 2147483647;
files = (
);
inputPaths = (
"${PODS_PODFILE_DIR_PATH}/Podfile.lock",
"${PODS_ROOT}/Manifest.lock",
);
name = "[CP] Check Pods Manifest.lock";
outputPaths = (
"$(DERIVED_FILE_DIR)/Pods-tflite_camera_example-checkManifestLockResult.txt",
);
runOnlyForDeploymentPostprocessing = 0;
shellPath = /bin/sh;
shellScript = "diff \"${PODS_PODFILE_DIR_PATH}/Podfile.lock\" \"${PODS_ROOT}/Manifest.lock\" > /dev/null\nif [ $? != 0 ] ; then\n # print error to STDERR\n echo \"error: The sandbox is not in sync with the Podfile.lock. Run 'pod install' or update your CocoaPods installation.\" >&2\n exit 1\nfi\n# This output is used by Xcode 'outputs' to avoid re-running this script phase.\necho \"SUCCESS\" > \"${SCRIPT_OUTPUT_FILE_0}\"\n";
showEnvVarsInLog = 0;
};
/* End PBXShellScriptBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
1C564C091ED3A92E00087306 /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
1CDB2D4A1ED3A9CD007929E9 /* CameraExampleViewController.mm in Sources */,
1CDB2D491ED3A9CD007929E9 /* CameraExampleAppDelegate.m in Sources */,
1C3C9DCC1ED3AB4200B8B5FA /* main.mm in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
1C564C361ED3A92E00087306 /* Debug */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 3BC5BE4BBD09374D3E98F082 /* Pods-tflite_camera_example.debug.xcconfig */;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
DEVELOPMENT_TEAM = EQHXZ8M8AV;
INFOPLIST_FILE = Info.plist;
IPHONEOS_DEPLOYMENT_TARGET = 10.3;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example";
PRODUCT_NAME = "$(TARGET_NAME)";
SWIFT_ACTIVE_COMPILATION_CONDITIONS = DEBUG;
SWIFT_OPTIMIZATION_LEVEL = "-Onone";
SWIFT_VERSION = 3.0;
};
name = Debug;
};
1C564C371ED3A92E00087306 /* Release */ = {
isa = XCBuildConfiguration;
baseConfigurationReference = 55ED318E8D29C8AFEF03DF1E /* Pods-tflite_camera_example.release.xcconfig */;
buildSettings = {
ASSETCATALOG_COMPILER_APPICON_NAME = AppIcon;
CLANG_ANALYZER_NONNULL = YES;
CLANG_ANALYZER_NUMBER_OBJECT_CONVERSION = YES_AGGRESSIVE;
CLANG_WARN_DOCUMENTATION_COMMENTS = YES;
DEVELOPMENT_TEAM = EQHXZ8M8AV;
INFOPLIST_FILE = Info.plist;
IPHONEOS_DEPLOYMENT_TARGET = 10.3;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
PRODUCT_BUNDLE_IDENTIFIER = "com.pf.tf-camera-example";
PRODUCT_NAME = "$(TARGET_NAME)";
SWIFT_OPTIMIZATION_LEVEL = "-Owholemodule";
SWIFT_VERSION = 3.0;
};
name = Release;
};
591157B01CF4011D00C31E3A /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
CLANG_CXX_LIBRARY = "libc++";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
GCC_C_LANGUAGE_STANDARD = gnu99;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = (
"$(inherited)",
../../../../../../,
../../../downloads/flatbuffers/include/,
../../../downloads/eigen/,
../../../downloads/,
);
IPHONEOS_DEPLOYMENT_TARGET = 8.0;
LIBRARY_SEARCH_PATHS = ../../../gen/lib/;
MTL_ENABLE_DEBUG_INFO = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
591157B11CF4011D00C31E3A /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
CLANG_CXX_LIBRARY = "libc++";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
GCC_C_LANGUAGE_STANDARD = gnu99;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
HEADER_SEARCH_PATHS = (
"$(inherited)",
../../../../../../,
../../../downloads/flatbuffers/include/,
../../../downloads/eigen/,
../../../downloads/,
);
IPHONEOS_DEPLOYMENT_TARGET = 8.0;
LIBRARY_SEARCH_PATHS = ../../../gen/lib/;
MTL_ENABLE_DEBUG_INFO = NO;
SDKROOT = iphoneos;
TARGETED_DEVICE_FAMILY = "1,2";
VALIDATE_PRODUCT = YES;
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
1C564C351ED3A92E00087306 /* Build configuration list for PBXNativeTarget "tflite_camera_example" */ = {
isa = XCConfigurationList;
buildConfigurations = (
1C564C361ED3A92E00087306 /* Debug */,
1C564C371ED3A92E00087306 /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
591157961CF4011C00C31E3A /* Build configuration list for PBXProject "tflite_camera_example" */ = {
isa = XCConfigurationList;
buildConfigurations = (
591157B01CF4011D00C31E3A /* Debug */,
591157B11CF4011D00C31E3A /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
};
rootObject = 591157931CF4011C00C31E3A /* Project object */;
}

View File

@ -0,0 +1,21 @@
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <UIKit/UIKit.h>
@interface AppDelegate : UIResponder<UIApplicationDelegate>
@property(strong, nonatomic) UIWindow *window;
@end

View File

@ -0,0 +1,47 @@
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "AppDelegate.h"
#import "RunModelViewController.h"
@implementation AppDelegate
- (BOOL)application:(UIApplication *)application
didFinishLaunchingWithOptions:(NSDictionary *)launchOptions {
UITabBarController *bar = [[UITabBarController alloc] init];
[bar setViewControllers:@[ [[RunModelViewController alloc] init] ]];
bar.selectedIndex = 0;
self.window = [[UIWindow alloc] initWithFrame:[[UIScreen mainScreen] bounds]];
self.window.rootViewController = bar;
[self.window makeKeyAndVisible];
return YES;
}
- (void)applicationWillResignActive:(UIApplication *)application {
}
- (void)applicationDidEnterBackground:(UIApplication *)application {
}
- (void)applicationWillEnterForeground:(UIApplication *)application {
}
- (void)applicationDidBecomeActive:(UIApplication *)application {
}
- (void)applicationWillTerminate:(UIApplication *)application {
}
@end

View File

@ -0,0 +1,5 @@
platform :ios, '8.0'
inhibit_all_warnings!
target 'tf_simple_example'
pod 'TensorFlow-experimental'

View File

@ -0,0 +1,47 @@
<?xml version="1.0" encoding="UTF-8"?>
<!DOCTYPE plist PUBLIC "-//Apple//DTD PLIST 1.0//EN" "http://www.apple.com/DTDs/PropertyList-1.0.dtd">
<plist version="1.0">
<dict>
<key>CFBundleDevelopmentRegion</key>
<string>en</string>
<key>CFBundleDisplayName</key>
<string>tflite-simple-example</string>
<key>CFBundleExecutable</key>
<string>tf_simple_example</string>
<key>CFBundleIdentifier</key>
<string>$(PRODUCT_BUNDLE_IDENTIFIER)</string>
<key>CFBundleInfoDictionaryVersion</key>
<string>6.0</string>
<key>CFBundleName</key>
<string>ios-app</string>
<key>CFBundlePackageType</key>
<string>APPL</string>
<key>CFBundleShortVersionString</key>
<string>1.0</string>
<key>CFBundleSignature</key>
<string>????</string>
<key>CFBundleVersion</key>
<string>1.0</string>
<key>LSRequiresIPhoneOS</key>
<true/>
<key>UILaunchStoryboardName</key>
<string>RunModelViewController</string>
<key>UIRequiredDeviceCapabilities</key>
<array>
<string>armv7</string>
</array>
<key>UISupportedInterfaceOrientations</key>
<array>
<string>UIInterfaceOrientationPortrait</string>
<string>UIInterfaceOrientationLandscapeLeft</string>
<string>UIInterfaceOrientationLandscapeRight</string>
</array>
<key>UISupportedInterfaceOrientations~ipad</key>
<array>
<string>UIInterfaceOrientationPortrait</string>
<string>UIInterfaceOrientationPortraitUpsideDown</string>
<string>UIInterfaceOrientationLandscapeLeft</string>
<string>UIInterfaceOrientationLandscapeRight</string>
</array>
</dict>
</plist>

View File

@ -0,0 +1,24 @@
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <UIKit/UIKit.h>
@interface RunModelViewController : UIViewController
- (IBAction)getUrl:(id)sender;
@property(weak, nonatomic) IBOutlet UITextView *urlContentTextView;
@property(weak, nonatomic) IBOutlet UITextField *urlTextField;
@end

View File

@ -0,0 +1,221 @@
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import "RunModelViewController.h"
#include <pthread.h>
#include <unistd.h>
#include <fstream>
#include <iostream>
#include <queue>
#include <sstream>
#include <string>
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
#include "ios_image_load.h"
#define LOG(x) std::cerr
#define CHECK(x) \
if (!(x)) { \
LOG(ERROR) << #x << " failed"; \
exit(1); \
}
NSString* RunInferenceOnImage();
@interface RunModelViewController ()
@end
@implementation RunModelViewController {
}
- (IBAction)getUrl:(id)sender {
NSString* inference_result = RunInferenceOnImage();
self.urlContentTextView.text = inference_result;
}
@end
// Returns the top N confidence values over threshold in the provided vector,
// sorted by confidence in descending order.
static void GetTopN(const float* prediction, const int prediction_size, const int num_results,
const float threshold, std::vector<std::pair<float, int> >* top_results) {
// Will contain top N results in ascending order.
std::priority_queue<std::pair<float, int>, std::vector<std::pair<float, int> >,
std::greater<std::pair<float, int> > >
top_result_pq;
const long count = prediction_size;
for (int i = 0; i < count; ++i) {
const float value = prediction[i];
// Only add it if it beats the threshold and has a chance at being in
// the top N.
if (value < threshold) {
continue;
}
top_result_pq.push(std::pair<float, int>(value, i));
// If at capacity, kick the smallest value out.
if (top_result_pq.size() > num_results) {
top_result_pq.pop();
}
}
// Copy to output vector and reverse into descending order.
while (!top_result_pq.empty()) {
top_results->push_back(top_result_pq.top());
top_result_pq.pop();
}
std::reverse(top_results->begin(), top_results->end());
}
NSString* FilePathForResourceName(NSString* name, NSString* extension) {
NSString* file_path = [[NSBundle mainBundle] pathForResource:name ofType:extension];
if (file_path == NULL) {
LOG(FATAL) << "Couldn't find '" << [name UTF8String] << "." << [extension UTF8String]
<< "' in bundle.";
}
return file_path;
}
NSString* RunInferenceOnImage() {
std::string graph;
const int num_threads = 1;
std::string input_layer_type = "float";
std::vector<int> sizes = {1, 224, 224, 3};
NSString* graph_path = FilePathForResourceName(@"mobilenet_v1_1.0_224", @"tflite");
std::unique_ptr<tflite::FlatBufferModel> model(
tflite::FlatBufferModel::BuildFromFile([graph_path UTF8String]));
if (!model) {
LOG(FATAL) << "Failed to mmap model " << graph;
}
LOG(INFO) << "Loaded model " << graph;
model->error_reporter();
LOG(INFO) << "resolved reporter";
#ifdef TFLITE_CUSTOM_OPS_HEADER
tflite::MutableOpResolver resolver;
RegisterSelectedOps(&resolver);
#else
tflite::ops::builtin::BuiltinOpResolver resolver;
#endif
std::unique_ptr<tflite::Interpreter> interpreter;
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
LOG(FATAL) << "Failed to construct interpreter";
}
if (num_threads != -1) {
interpreter->SetNumThreads(num_threads);
}
int input = interpreter->inputs()[0];
if (input_layer_type != "string") {
interpreter->ResizeInputTensor(input, sizes);
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(FATAL) << "Failed to allocate tensors!";
}
// Read the label list
NSString* labels_path = FilePathForResourceName(@"labels", @"txt");
std::vector<std::string> label_strings;
std::ifstream t;
t.open([labels_path UTF8String]);
std::string line;
while (t) {
std::getline(t, line);
label_strings.push_back(line);
}
t.close();
// Read the Grace Hopper image.
NSString* image_path = FilePathForResourceName(@"grace_hopper", @"jpg");
int image_width;
int image_height;
int image_channels;
std::vector<uint8_t> image_data =
LoadImageFromFile([image_path UTF8String], &image_width, &image_height, &image_channels);
const int wanted_width = 224;
const int wanted_height = 224;
const int wanted_channels = 3;
const float input_mean = 127.5f;
const float input_std = 127.5f;
assert(image_channels >= wanted_channels);
uint8_t* in = image_data.data();
float* out = interpreter->typed_tensor<float>(input);
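// Scale the image down to the 224x224 network input with nearest-neighbor
// sampling, normalizing each channel to [-1, 1] via (value - mean) / std.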
for (int y = 0; y < wanted_height; ++y) {
const int in_y = (y * image_height) / wanted_height;
uint8_t* in_row = in + (in_y * image_width * image_channels);
float* out_row = out + (y * wanted_width * wanted_channels);
for (int x = 0; x < wanted_width; ++x) {
const int in_x = (x * image_width) / wanted_width;
uint8_t* in_pixel = in_row + (in_x * image_channels);
float* out_pixel = out_row + (x * wanted_channels);
for (int c = 0; c < wanted_channels; ++c) {
out_pixel[c] = (in_pixel[c] - input_mean) / input_std;
}
}
}
if (interpreter->Invoke() != kTfLiteOk) {
LOG(FATAL) << "Failed to invoke!";
}
float* output = interpreter->typed_output_tensor<float>(0);
const int output_size = 1000;
const int kNumResults = 5;
const float kThreshold = 0.1f;
std::vector<std::pair<float, int> > top_results;
GetTopN(output, output_size, kNumResults, kThreshold, &top_results);
std::stringstream ss;
ss.precision(3);
for (const auto& result : top_results) {
const float confidence = result.first;
const int index = result.second;
ss << index << " " << confidence << " ";
// Write out the result as a string
if (index < label_strings.size()) {
// just for safety: theoretically, the output is under 1000 unless there
// are numerical issues leading to a wrong prediction.
ss << label_strings[index];
} else {
ss << "Prediction: " << index;
}
ss << "\n";
}
LOG(INFO) << "Predictions: " << ss.str();
std::string predictions = ss.str();
NSString* result = @"";
result = [NSString stringWithFormat:@"%@ - %s", result, predictions.c_str()];
return result;
}

View File

@ -0,0 +1,46 @@
<?xml version="1.0" encoding="UTF-8" standalone="no"?>
<document type="com.apple.InterfaceBuilder3.CocoaTouch.XIB" version="3.0" toolsVersion="9531" systemVersion="15D21" targetRuntime="iOS.CocoaTouch" propertyAccessControl="none" useAutolayout="YES">
<dependencies>
<plugIn identifier="com.apple.InterfaceBuilder.IBCocoaTouchPlugin" version="9529"/>
</dependencies>
<objects>
<placeholder placeholderIdentifier="IBFilesOwner" id="-1" userLabel="File's Owner" customClass="RunModelViewController">
<connections>
<outlet property="urlContentTextView" destination="quY-AK-ZCn" id="YjW-BO-1Ta"/>
<outlet property="urlTextField" destination="hPw-q5-vh5" id="wmc-b6-2CV"/>
<outlet property="view" destination="1" id="iHm-Rr-4wj"/>
</connections>
</placeholder>
<placeholder placeholderIdentifier="IBFirstResponder" id="-2" customClass="UIResponder"/>
<view contentMode="scaleToFill" id="1">
<rect key="frame" x="0.0" y="0.0" width="320" height="568"/>
<autoresizingMask key="autoresizingMask" widthSizable="YES" heightSizable="YES"/>
<subviews>
<textView clipsSubviews="YES" contentMode="scaleToFill" fixedFrame="YES" editable="NO" text="The results of running the model will appear here." selectable="NO" translatesAutoresizingMaskIntoConstraints="NO" id="quY-AK-ZCn">
<rect key="frame" x="40" y="99" width="240" height="168"/>
<color key="backgroundColor" white="1" alpha="1" colorSpace="calibratedWhite"/>
<fontDescription key="fontDescription" type="system" pointSize="14"/>
<textInputTraits key="textInputTraits" autocapitalizationType="sentences"/>
</textView>
<button opaque="NO" contentMode="scaleToFill" fixedFrame="YES" contentHorizontalAlignment="center" contentVerticalAlignment="center" buttonType="roundedRect" lineBreakMode="middleTruncation" translatesAutoresizingMaskIntoConstraints="NO" id="AAC-Bk-PCC">
<rect key="frame" x="76" y="37" width="168" height="30"/>
<color key="backgroundColor" white="0.33333333333333331" alpha="1" colorSpace="calibratedWhite"/>
<state key="normal" title="Run Model">
<color key="titleShadowColor" white="0.5" alpha="1" colorSpace="calibratedWhite"/>
</state>
<connections>
<action selector="getUrl:" destination="-1" eventType="touchUpInside" id="mdP-nK-k9T"/>
</connections>
</button>
</subviews>
<color key="backgroundColor" red="0.78314738357315861" green="0.79869981749999996" blue="0.56305065858222869" alpha="1" colorSpace="calibratedRGB"/>
</view>
<textField opaque="NO" clipsSubviews="YES" contentMode="scaleToFill" contentHorizontalAlignment="left" contentVerticalAlignment="center" text="http://localhost:8080" borderStyle="roundedRect" placeholder="Enter URL" minimumFontSize="17" id="hPw-q5-vh5">
<rect key="frame" x="0.0" y="0.0" width="280" height="30"/>
<autoresizingMask key="autoresizingMask" flexibleMaxX="YES" flexibleMaxY="YES"/>
<fontDescription key="fontDescription" type="system" pointSize="14"/>
<textInputTraits key="textInputTraits"/>
<point key="canvasLocation" x="795" y="44"/>
</textField>
</objects>
</document>

Binary file not shown.


View File

@ -0,0 +1,23 @@
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#ifndef TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
#define TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_
#include <vector>
std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width,
int* out_height, int* out_channels);
#endif // TENSORFLOW_EXAMPLES_IOS_IOS_IMAGE_LOAD_H_

View File

@ -0,0 +1,80 @@
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#include "ios_image_load.h"
#include <assert.h>
#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#import <CoreImage/CoreImage.h>
#import <ImageIO/ImageIO.h>
std::vector<uint8_t> LoadImageFromFile(const char* file_name, int* out_width, int* out_height,
int* out_channels) {
FILE* file_handle = fopen(file_name, "rb");
fseek(file_handle, 0, SEEK_END);
const size_t bytes_in_file = ftell(file_handle);
fseek(file_handle, 0, SEEK_SET);
std::vector<uint8_t> file_data(bytes_in_file);
fread(file_data.data(), 1, bytes_in_file, file_handle);
fclose(file_handle);
CFDataRef file_data_ref =
CFDataCreateWithBytesNoCopy(NULL, file_data.data(), bytes_in_file, kCFAllocatorNull);
CGDataProviderRef image_provider = CGDataProviderCreateWithCFData(file_data_ref);
const char* suffix = strrchr(file_name, '.');
if (!suffix || suffix == file_name) {
suffix = "";
}
CGImageRef image;
if (strcasecmp(suffix, ".png") == 0) {
image = CGImageCreateWithPNGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault);
} else if ((strcasecmp(suffix, ".jpg") == 0) || (strcasecmp(suffix, ".jpeg") == 0)) {
image =
CGImageCreateWithJPEGDataProvider(image_provider, NULL, true, kCGRenderingIntentDefault);
} else {
CFRelease(image_provider);
CFRelease(file_data_ref);
fprintf(stderr, "Unknown suffix for file '%s'\n", file_name);
*out_width = 0;
*out_height = 0;
*out_channels = 0;
return std::vector<uint8_t>();
}
const int width = (int)CGImageGetWidth(image);
const int height = (int)CGImageGetHeight(image);
const int channels = 4;
CGColorSpaceRef color_space = CGColorSpaceCreateDeviceRGB();
const int bytes_per_row = (width * channels);
const int bytes_in_image = (bytes_per_row * height);
std::vector<uint8_t> result(bytes_in_image);
const int bits_per_component = 8;
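// Draw the decoded image into an RGBA8888 bitmap backed by `result`, so the
// pixels come back as one flat byte vector regardless of the source format.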
CGContextRef context =
CGBitmapContextCreate(result.data(), width, height, bits_per_component, bytes_per_row,
color_space, kCGImageAlphaPremultipliedLast | kCGBitmapByteOrder32Big);
CGColorSpaceRelease(color_space);
CGContextDrawImage(context, CGRectMake(0, 0, width, height), image);
CGContextRelease(context);
CFRelease(image);
CFRelease(image_provider);
CFRelease(file_data_ref);
*out_width = width;
*out_height = height;
*out_channels = channels;
return result;
}

View File

@ -0,0 +1,22 @@
// Copyright 2015 Google Inc. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
#import <UIKit/UIKit.h>
int main(int argc, char *argv[]) {
@autoreleasepool {
NSString *delegateClassName = @"AppDelegate";
return UIApplicationMain(argc, argv, nil, delegateClassName);
}
}

View File

@ -0,0 +1,359 @@
// !$*UTF8*$!
{
archiveVersion = 1;
classes = {
};
objectVersion = 46;
objects = {
/* Begin PBXBuildFile section */
1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */; };
1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */ = {isa = PBXBuildFile; fileRef = 1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */; };
594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */ = {isa = PBXBuildFile; fileRef = 594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */; };
594C14B11FB9037100EE8BFE /* labels.txt in Resources */ = {isa = PBXBuildFile; fileRef = 594C14AF1FB9037100EE8BFE /* labels.txt */; };
594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */ = {isa = PBXBuildFile; fileRef = 594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */; };
59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFF21CF4E68100C4259F /* AppDelegate.mm */; };
59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */ = {isa = PBXBuildFile; fileRef = 59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */; };
59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */; };
59A3D0091CF4E68100C4259F /* main.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFC1CF4E68100C4259F /* main.mm */; };
59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */ = {isa = PBXBuildFile; fileRef = 59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */; };
59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */ = {isa = PBXBuildFile; fileRef = 59A3D0001CF4E68100C4259F /* RunModelViewController.xib */; };
/* End PBXBuildFile section */
/* Begin PBXFileReference section */
1C0D73481ECCC41B008C1DAB /* CoreImage.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreImage.framework; path = System/Library/Frameworks/CoreImage.framework; sourceTree = SDKROOT; };
1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = CoreGraphics.framework; path = System/Library/Frameworks/CoreGraphics.framework; sourceTree = SDKROOT; };
1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */ = {isa = PBXFileReference; lastKnownFileType = wrapper.framework; name = UIKit.framework; path = System/Library/Frameworks/UIKit.framework; sourceTree = SDKROOT; };
5911579B1CF4011C00C31E3A /* tf_simple_example.app */ = {isa = PBXFileReference; explicitFileType = wrapper.application; includeInIndex = 0; path = tf_simple_example.app; sourceTree = BUILT_PRODUCTS_DIR; };
594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */ = {isa = PBXFileReference; lastKnownFileType = archive.ar; name = "libtensorflow-lite.a"; path = "../../../gen/lib/libtensorflow-lite.a"; sourceTree = "<group>"; };
594C14AF1FB9037100EE8BFE /* labels.txt */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text; path = labels.txt; sourceTree = "<group>"; };
594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */ = {isa = PBXFileReference; lastKnownFileType = file; path = mobilenet_v1_1.0_224.tflite; sourceTree = "<group>"; };
59A3CFF11CF4E68100C4259F /* AppDelegate.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = AppDelegate.h; sourceTree = "<group>"; };
59A3CFF21CF4E68100C4259F /* AppDelegate.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = AppDelegate.mm; sourceTree = "<group>"; };
59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */ = {isa = PBXFileReference; lastKnownFileType = image.jpeg; path = grace_hopper.jpg; sourceTree = "<group>"; };
59A3CFFA1CF4E68100C4259F /* ios_image_load.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = ios_image_load.h; sourceTree = "<group>"; };
59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = ios_image_load.mm; sourceTree = "<group>"; };
59A3CFFC1CF4E68100C4259F /* main.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = main.mm; sourceTree = "<group>"; };
59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = text.plist.xml; path = "RunModel-Info.plist"; sourceTree = "<group>"; };
59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.c.h; path = RunModelViewController.h; sourceTree = "<group>"; };
59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = sourcecode.cpp.objcpp; path = RunModelViewController.mm; sourceTree = "<group>"; };
59A3D0001CF4E68100C4259F /* RunModelViewController.xib */ = {isa = PBXFileReference; fileEncoding = 4; lastKnownFileType = file.xib; path = RunModelViewController.xib; sourceTree = "<group>"; };
73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */ = {isa = PBXFileReference; explicitFileType = archive.ar; includeInIndex = 0; path = "libPods-tf_simple_example.a"; sourceTree = BUILT_PRODUCTS_DIR; };
/* End PBXFileReference section */
/* Begin PBXFrameworksBuildPhase section */
591157981CF4011C00C31E3A /* Frameworks */ = {
isa = PBXFrameworksBuildPhase;
buildActionMask = 2147483647;
files = (
594C14AE1FB8F9B500EE8BFE /* libtensorflow-lite.a in Frameworks */,
1C0D734B1ECCC460008C1DAB /* CoreGraphics.framework in Frameworks */,
1CA45FFF1ECCC356002FA6A4 /* UIKit.framework in Frameworks */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXFrameworksBuildPhase section */
/* Begin PBXGroup section */
24D7686C331131624F4454A0 /* Frameworks */ = {
isa = PBXGroup;
children = (
594C14AD1FB8F9B500EE8BFE /* libtensorflow-lite.a */,
1C0D734A1ECCC460008C1DAB /* CoreGraphics.framework */,
1C0D73481ECCC41B008C1DAB /* CoreImage.framework */,
1CA45FFE1ECCC356002FA6A4 /* UIKit.framework */,
73DBC33C5DD9A526EE6D1EF2 /* libPods-tf_simple_example.a */,
);
name = Frameworks;
sourceTree = "<group>";
};
591157921CF4011C00C31E3A = {
isa = PBXGroup;
children = (
59A3CFF11CF4E68100C4259F /* AppDelegate.h */,
59A3CFF21CF4E68100C4259F /* AppDelegate.mm */,
59A3CFF31CF4E68100C4259F /* data */,
59A3CFFA1CF4E68100C4259F /* ios_image_load.h */,
59A3CFFB1CF4E68100C4259F /* ios_image_load.mm */,
59A3CFFC1CF4E68100C4259F /* main.mm */,
59A3CFFD1CF4E68100C4259F /* RunModel-Info.plist */,
59A3CFFE1CF4E68100C4259F /* RunModelViewController.h */,
59A3CFFF1CF4E68100C4259F /* RunModelViewController.mm */,
59A3D0001CF4E68100C4259F /* RunModelViewController.xib */,
5911579C1CF4011C00C31E3A /* Products */,
24D7686C331131624F4454A0 /* Frameworks */,
);
sourceTree = "<group>";
};
5911579C1CF4011C00C31E3A /* Products */ = {
isa = PBXGroup;
children = (
5911579B1CF4011C00C31E3A /* tf_simple_example.app */,
);
name = Products;
sourceTree = "<group>";
};
59A3CFF31CF4E68100C4259F /* data */ = {
isa = PBXGroup;
children = (
59A3CFF51CF4E68100C4259F /* grace_hopper.jpg */,
594C14AF1FB9037100EE8BFE /* labels.txt */,
594C14B01FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite */,
);
path = data;
sourceTree = "<group>";
};
/* End PBXGroup section */
/* Begin PBXNativeTarget section */
5911579A1CF4011C00C31E3A /* tf_simple_example */ = {
isa = PBXNativeTarget;
buildConfigurationList = 591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */;
buildPhases = (
591157971CF4011C00C31E3A /* Sources */,
591157981CF4011C00C31E3A /* Frameworks */,
591157991CF4011C00C31E3A /* Resources */,
);
buildRules = (
);
dependencies = (
);
name = tf_simple_example;
productName = tf_ios_makefile_example;
productReference = 5911579B1CF4011C00C31E3A /* tf_simple_example.app */;
productType = "com.apple.product-type.application";
};
/* End PBXNativeTarget section */
/* Begin PBXProject section */
591157931CF4011C00C31E3A /* Project object */ = {
isa = PBXProject;
attributes = {
LastUpgradeCheck = 0830;
ORGANIZATIONNAME = Google;
TargetAttributes = {
5911579A1CF4011C00C31E3A = {
CreatedOnToolsVersion = 7.2;
DevelopmentTeam = EQHXZ8M8AV;
ProvisioningStyle = Manual;
};
};
};
buildConfigurationList = 591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */;
compatibilityVersion = "Xcode 3.2";
developmentRegion = English;
hasScannedForEncodings = 0;
knownRegions = (
en,
Base,
);
mainGroup = 591157921CF4011C00C31E3A;
productRefGroup = 5911579C1CF4011C00C31E3A /* Products */;
projectDirPath = "";
projectRoot = "";
targets = (
5911579A1CF4011C00C31E3A /* tf_simple_example */,
);
};
/* End PBXProject section */
/* Begin PBXResourcesBuildPhase section */
591157991CF4011C00C31E3A /* Resources */ = {
isa = PBXResourcesBuildPhase;
buildActionMask = 2147483647;
files = (
59A3D00C1CF4E68100C4259F /* RunModelViewController.xib in Resources */,
594C14B11FB9037100EE8BFE /* labels.txt in Resources */,
59A3D0031CF4E68100C4259F /* grace_hopper.jpg in Resources */,
594C14B21FB9037100EE8BFE /* mobilenet_v1_1.0_224.tflite in Resources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXResourcesBuildPhase section */
/* Begin PBXSourcesBuildPhase section */
591157971CF4011C00C31E3A /* Sources */ = {
isa = PBXSourcesBuildPhase;
buildActionMask = 2147483647;
files = (
59A3D0091CF4E68100C4259F /* main.mm in Sources */,
59A3D0011CF4E68100C4259F /* AppDelegate.mm in Sources */,
59A3D00B1CF4E68100C4259F /* RunModelViewController.mm in Sources */,
59A3D0081CF4E68100C4259F /* ios_image_load.mm in Sources */,
);
runOnlyForDeploymentPostprocessing = 0;
};
/* End PBXSourcesBuildPhase section */
/* Begin XCBuildConfiguration section */
591157B01CF4011D00C31E3A /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
CLANG_CXX_LIBRARY = "libc++";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = dwarf;
ENABLE_STRICT_OBJC_MSGSEND = YES;
ENABLE_TESTABILITY = YES;
GCC_C_LANGUAGE_STANDARD = gnu99;
GCC_DYNAMIC_NO_PIC = NO;
GCC_NO_COMMON_BLOCKS = YES;
GCC_OPTIMIZATION_LEVEL = 0;
GCC_PREPROCESSOR_DEFINITIONS = (
"DEBUG=1",
"$(inherited)",
);
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 8.0;
MTL_ENABLE_DEBUG_INFO = YES;
ONLY_ACTIVE_ARCH = YES;
SDKROOT = iphoneos;
TARGETED_DEVICE_FAMILY = "1,2";
};
name = Debug;
};
591157B11CF4011D00C31E3A /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
ALWAYS_SEARCH_USER_PATHS = NO;
CLANG_CXX_LANGUAGE_STANDARD = "gnu++0x";
CLANG_CXX_LIBRARY = "libc++";
CLANG_ENABLE_MODULES = YES;
CLANG_ENABLE_OBJC_ARC = YES;
CLANG_WARN_BOOL_CONVERSION = YES;
CLANG_WARN_CONSTANT_CONVERSION = YES;
CLANG_WARN_DIRECT_OBJC_ISA_USAGE = YES_ERROR;
CLANG_WARN_EMPTY_BODY = YES;
CLANG_WARN_ENUM_CONVERSION = YES;
CLANG_WARN_INFINITE_RECURSION = YES;
CLANG_WARN_INT_CONVERSION = YES;
CLANG_WARN_OBJC_ROOT_CLASS = YES_ERROR;
CLANG_WARN_SUSPICIOUS_MOVE = YES;
CLANG_WARN_UNREACHABLE_CODE = YES;
CLANG_WARN__DUPLICATE_METHOD_MATCH = YES;
"CODE_SIGN_IDENTITY[sdk=iphoneos*]" = "iPhone Developer";
COPY_PHASE_STRIP = NO;
DEBUG_INFORMATION_FORMAT = "dwarf-with-dsym";
ENABLE_NS_ASSERTIONS = NO;
ENABLE_STRICT_OBJC_MSGSEND = YES;
GCC_C_LANGUAGE_STANDARD = gnu99;
GCC_NO_COMMON_BLOCKS = YES;
GCC_WARN_64_TO_32_BIT_CONVERSION = YES;
GCC_WARN_ABOUT_RETURN_TYPE = YES_ERROR;
GCC_WARN_UNDECLARED_SELECTOR = YES;
GCC_WARN_UNINITIALIZED_AUTOS = YES_AGGRESSIVE;
GCC_WARN_UNUSED_FUNCTION = YES;
GCC_WARN_UNUSED_VARIABLE = YES;
IPHONEOS_DEPLOYMENT_TARGET = 8.0;
MTL_ENABLE_DEBUG_INFO = NO;
SDKROOT = iphoneos;
TARGETED_DEVICE_FAMILY = "1,2";
VALIDATE_PRODUCT = YES;
};
name = Release;
};
591157B31CF4011D00C31E3A /* Debug */ = {
isa = XCBuildConfiguration;
buildSettings = {
CLANG_DEBUG_INFORMATION_LEVEL = default;
CODE_SIGN_IDENTITY = "iPhone Developer";
DEVELOPMENT_TEAM = EQHXZ8M8AV;
ENABLE_BITCODE = NO;
GCC_ENABLE_CPP_EXCEPTIONS = YES;
GCC_ENABLE_CPP_RTTI = YES;
HEADER_SEARCH_PATHS = (
"$(inherited)",
../../../../../../,
../../../downloads/flatbuffers/include/,
../../../downloads/eigen/,
../../../downloads/,
);
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.2;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
LIBRARY_SEARCH_PATHS = ../../../gen/lib/;
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
OTHER_LDFLAGS = "$(inherited)";
PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example";
PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE = "1072bd47-ff19-4e5f-8107-d912748f83f1";
PROVISIONING_PROFILE_SPECIFIER = "Google Development";
SEPARATE_STRIP = NO;
};
name = Debug;
};
591157B41CF4011D00C31E3A /* Release */ = {
isa = XCBuildConfiguration;
buildSettings = {
CLANG_DEBUG_INFORMATION_LEVEL = default;
CODE_SIGN_IDENTITY = "iPhone Developer";
DEVELOPMENT_TEAM = "";
ENABLE_BITCODE = NO;
GCC_ENABLE_CPP_EXCEPTIONS = YES;
GCC_ENABLE_CPP_RTTI = YES;
HEADER_SEARCH_PATHS = (
"$(inherited)",
../../../../../../,
../../../downloads/flatbuffers/include/,
../../../downloads/eigen/,
../../../downloads/,
);
INFOPLIST_FILE = "$(SRCROOT)/RunModel-Info.plist";
IPHONEOS_DEPLOYMENT_TARGET = 9.2;
LD_RUNPATH_SEARCH_PATHS = "$(inherited) @executable_path/Frameworks";
LIBRARY_SEARCH_PATHS = ../../../gen/lib/;
ONLY_ACTIVE_ARCH = YES;
OTHER_CPLUSPLUSFLAGS = "$(OTHER_CFLAGS)";
OTHER_LDFLAGS = "$(inherited)";
PRODUCT_BUNDLE_IDENTIFIER = "com.google.tflite-simple-example";
PRODUCT_NAME = "$(TARGET_NAME)";
PROVISIONING_PROFILE_SPECIFIER = "";
SEPARATE_STRIP = NO;
};
name = Release;
};
/* End XCBuildConfiguration section */
/* Begin XCConfigurationList section */
591157961CF4011C00C31E3A /* Build configuration list for PBXProject "simple" */ = {
isa = XCConfigurationList;
buildConfigurations = (
591157B01CF4011D00C31E3A /* Debug */,
591157B11CF4011D00C31E3A /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
591157B21CF4011D00C31E3A /* Build configuration list for PBXNativeTarget "tf_simple_example" */ = {
isa = XCConfigurationList;
buildConfigurations = (
591157B31CF4011D00C31E3A /* Debug */,
591157B41CF4011D00C31E3A /* Release */,
);
defaultConfigurationIsVisible = 0;
defaultConfigurationName = Release;
};
/* End XCConfigurationList section */
};
rootObject = 591157931CF4011C00C31E3A /* Project object */;
}

View File

@ -267,7 +267,7 @@ try (Interpreter interpreter = new Interpreter(file_of_a_tensorflowlite_model))
The `Interpreter.java` class drives model inference with TensorFlow Lite. In
most of the cases, this is the only class an app developer will need.
#### Initializing an `Interpreter` Mith a Model Mile
#### Initializing an `Interpreter` With a Model File
The `Interpreter` can be initialized with a model file using the constructor:

View File

@ -54,7 +54,7 @@ counterparts:
* [tf.sigmoid](https://www.tensorflow.org/api_docs/python/tf/sigmoid)
* [tf.space_to_depth](https://www.tensorflow.org/api_docs/python/tf/space_to_depth)
## Straighforward Conversions, Constant-Folding and Fusing
## Straightforward Conversions, Constant-Folding and Fusing
A number of TensorFlow operations can be processed by TensorFlow Lite even
though they have no direct equivalent. This is the case for operations that can

View File

@ -0,0 +1,31 @@
# Settings for iOS.
ifeq ($(TARGET), IOS)
  BUILD_FOR_IOS_SIMULATOR := false
  ifeq ($(IOS_ARCH), x86_64)
    BUILD_FOR_IOS_SIMULATOR := true
  endif
  ifeq ($(IOS_ARCH), i386)
    BUILD_FOR_IOS_SIMULATOR := true
  endif
  ifeq ($(BUILD_FOR_IOS_SIMULATOR), true)
    IPHONEOS_PLATFORM := $(shell xcrun --sdk iphonesimulator --show-sdk-platform-path)
    IPHONEOS_SYSROOT := $(shell xcrun --sdk iphonesimulator --show-sdk-path)
  else
    IPHONEOS_PLATFORM := $(shell xcrun --sdk iphoneos --show-sdk-platform-path)
    IPHONEOS_SYSROOT := $(shell xcrun --sdk iphoneos --show-sdk-path)
  endif
  IOS_SDK_VERSION := $(shell xcrun --sdk iphoneos --show-sdk-version)
  MIN_SDK_VERSION := 9.0
  # Override IOS_ARCH with armv7, armv7s, arm64, i386, or x86_64.
  IOS_ARCH := x86_64
  CXXFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) -DGEMMLOWP_ALLOW_SLOW_SCALAR_FALLBACK \
    -fembed-bitcode -Wno-c++11-narrowing -mno-thumb -fno-exceptions \
    -isysroot ${IPHONEOS_SYSROOT} -arch $(IOS_ARCH) -O3
  CCFLAGS += -miphoneos-version-min=$(MIN_SDK_VERSION) -fembed-bitcode \
    -mno-thumb -isysroot ${IPHONEOS_SYSROOT} -arch $(IOS_ARCH) -O3
  LDFLAGS := -fembed-bitcode -miphoneos-version-min=${MIN_SDK_VERSION} -arch $(IOS_ARCH)
  OBJDIR := $(OBJDIR)ios_$(IOS_ARCH)/
  LIBDIR := $(LIBDIR)ios_$(IOS_ARCH)/
  BINDIR := $(BINDIR)ios_$(IOS_ARCH)/
  DEPDIR := $(DEPDIR)ios_$(IOS_ARCH)/
endif

View File

@ -36,8 +36,8 @@ android {
}
repositories {
flatDir {
dirs 'libs'
maven {
url 'https://google.bintray.com/tensorflow'
}
}

View File

@ -86,25 +86,34 @@ same input.
### Models:
[Speech hotword model (Svdf rank=1)] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_hotword_model_rank1.tflite)
[Speech hotword model (Svdf
rank=1)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank1_2017_11_14.tflite)
[Speech hotword model (Svdf rank=2)] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_hotword_model_rank2.tflite)
[Speech hotword model (Svdf
rank=2)](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_hotword_model_rank2_2017_11_14.tflite)
[Speaker-id model] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_speakerid_model.tflite)
[Speaker-id
model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_speakerid_model_2017_11_14.tflite)
[TTS model] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_tts_model.tflite)
[TTS
model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_tts_model_2017_11_14.tflite)
[ASR AM model] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/testdata/speech_terse_am_model.tflite)
[ASR AM
model](https://storage.googleapis.com/download.tensorflow.org/models/tflite/speech_terse_am_model_2017_11_14.tflite)
### Test benches
[Speech hotword model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc)
[Speech hotword model
test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_hotword_model_test.cc)
[Speaker-id model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc)
[Speaker-id model
test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_speakerid_model_test.cc)
[TTS model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc)
[TTS model
test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_tts_model_test.cc)
[ASR AM model test] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc)
[ASR AM model
test](https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/models/speech_terse_am_model_test.cc)
## Android Support
The models have been tested on Android phones, using the following tests:
@ -112,5 +121,3 @@ The models have been tested on Android phones, using the following tests:
[Hotword] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=25)
[Speaker-id] (https://github.com/tensorflow/tensorflow/tree/master/tensorflow/contrib/lite/android/BUILD?rcl=172930882&l=36)

View File

@ -1454,9 +1454,9 @@ inline int ANeuralNetworksModel_finish(ANeuralNetworksModel* model) {
* {@link ANeuralNetworksExecution_setOutputFromMemory} and
* {@link ANeuralNetworksExecution_setOperandValue}.
*
* To build a model that can accomodate inputs of various sizes, as you may want
* to do for a CNN, set the size of the dimensions that will vary at run time to
* 0. If you do so, provide the full dimensions when calling
* To build a model that can accommodate inputs of various sizes, as you may
* want to do for a CNN, set the size of the dimensions that will vary at run
* time to 0. If you do so, provide the full dimensions when calling
* {@link ANeuralNetworksExecution_setInput} or {@link
* ANeuralNetworksExecution_setInputFromMemory}.
*
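A minimal C++ sketch of that pattern (the helper names, the input index 0, and the 224x224x3 shape are illustrative, and the model/execution objects are assumed to have been created through the usual NNAPI calls):
#include <android/NeuralNetworks.h>
// At model-build time: declare a float tensor whose height and width are not
// yet known, by setting those dimensions to 0.
void AddVariableSizeInput(ANeuralNetworksModel* model) {
  uint32_t dims[4] = {1, 0, 0, 3};  // batch, height, width, channels
  ANeuralNetworksOperandType tensor_type;
  tensor_type.type = ANEURALNETWORKS_TENSOR_FLOAT32;
  tensor_type.dimensionCount = 4;
  tensor_type.dimensions = dims;
  tensor_type.scale = 0.f;
  tensor_type.zeroPoint = 0;
  ANeuralNetworksModel_addOperand(model, &tensor_type);
}
// At execution time: supply the full dimensions alongside the input buffer.
void SetConcreteInput(ANeuralNetworksExecution* execution, const float* data,
                      size_t length_bytes) {
  uint32_t dims[4] = {1, 224, 224, 3};  // now fully specified
  ANeuralNetworksOperandType tensor_type;
  tensor_type.type = ANEURALNETWORKS_TENSOR_FLOAT32;
  tensor_type.dimensionCount = 4;
  tensor_type.dimensions = dims;
  tensor_type.scale = 0.f;
  tensor_type.zeroPoint = 0;
  ANeuralNetworksExecution_setInput(execution, /*index=*/0, &tensor_type,
                                    data, length_bytes);
}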

View File

@ -252,7 +252,7 @@ def JsonDumpAndFlush(data, fp):
class TestSchemaUpgrade(test_util.TensorFlowTestCase):
def testNonExistantFile(self):
def testNonExistentFile(self):
converter = upgrade_schema_lib.Converter()
non_existent = tempfile.mktemp(suffix=".json")
with self.assertRaisesRegexp(IOError, "No such file or directory"):

View File

@ -187,6 +187,7 @@ tf_cc_test(
srcs = ["generated_examples_zip_test.cc"],
data = [":optest"],
shard_count = 10,
tags = ["no_oss"],
deps = [
":parse_testdata_lib",
"//tensorflow/contrib/lite:builtin_op_data",

View File

@ -232,7 +232,7 @@ TfLiteStatus CheckOutputs(tflite::Interpreter* interpreter,
// invoke {
// id: xyz
// input: 1,2,1,1,1,2,3,4
// ouput: 4,5,6
// output: 4,5,6
// }
class Invoke : public Message {
public:

View File

@ -63,7 +63,7 @@ class TestRunner {
// Run the model.
virtual void Invoke() = 0;
// Verify that the contents of all ouputs conform to the existing
// Verify that the contents of all outputs conform to the existing
// expectations. Return true if there are no expectations or they are all
// satisfied.
virtual bool CheckResults() = 0;

View File

@ -129,7 +129,7 @@ enum class AxesOrder {
// The type of the scalars in an array.
// Note that that does not by itself tell whether the values in the array are
// real (are literally interpreted as real numbers) or quantized (only acquire
// a meaning as real numbers in conjuction with QuantizationParams).
// a meaning as real numbers in conjunction with QuantizationParams).
//
// In practice though:
// float values are always real
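The meaning QuantizationParams gives to quantized values is the usual affine mapping, real_value = scale * (quantized_value - zero_point); a minimal C++ sketch of dequantization under that convention:
#include <cstdint>
#include <vector>
// Map quantized uint8 values back to real numbers:
// real = scale * (q - zero_point).
std::vector<float> Dequantize(const std::vector<uint8_t>& quantized,
                              float scale, int32_t zero_point) {
  std::vector<float> real(quantized.size());
  for (size_t i = 0; i < quantized.size(); ++i) {
    real[i] = scale * (static_cast<int32_t>(quantized[i]) - zero_point);
  }
  return real;
}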

View File

@ -0,0 +1,95 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include <cstdarg>
#include <cstdlib>
#include <iostream>
#include <memory>
#include <string>
#include <unordered_set>
#include <vector>
#include "tensorflow/contrib/lite/kernels/register.h"
#include "tensorflow/contrib/lite/model.h"
#include "tensorflow/contrib/lite/string_util.h"
#include "tensorflow/contrib/lite/tools/mutable_op_resolver.h"
#ifdef TFLITE_CUSTOM_OPS_HEADER
void RegisterSelectedOps(::tflite::MutableOpResolver* resolver);
#endif
#define LOG(x) std::cerr
#define CHECK(x) \
if (!(x)) { \
LOG(ERROR) << #x << " failed"; \
exit(1); \
}
namespace tensorflow {
namespace benchmark_tflite_model {
std::unique_ptr<tflite::FlatBufferModel> model;
std::unique_ptr<tflite::Interpreter> interpreter;
void InitImpl(const std::string& graph, const std::vector<int>& sizes,
const std::string& input_layer_type, int num_threads) {
CHECK(graph.c_str());
model = tflite::FlatBufferModel::BuildFromFile(graph.c_str());
if (!model) {
LOG(FATAL) << "Failed to mmap model " << graph;
}
LOG(INFO) << "Loaded model " << graph;
model->error_reporter();
LOG(INFO) << "resolved reporter";
#ifdef TFLITE_CUSTOM_OPS_HEADER
tflite::MutableOpResolver resolver;
RegisterSelectedOps(&resolver);
#else
tflite::ops::builtin::BuiltinOpResolver resolver;
#endif
tflite::InterpreterBuilder(*model, resolver)(&interpreter);
if (!interpreter) {
LOG(FATAL) << "Failed to construct interpreter";
}
if (num_threads != -1) {
interpreter->SetNumThreads(num_threads);
}
int input = interpreter->inputs()[0];
if (input_layer_type != "string") {
interpreter->ResizeInputTensor(input, sizes);
}
if (interpreter->AllocateTensors() != kTfLiteOk) {
LOG(FATAL) << "Failed to allocate tensors!";
}
}
int Main(int argc, char** argv) {
InitImpl("", {}, "", 1);
return 0;
}
} // namespace benchmark_tflite_model
} // namespace tensorflow
int main(int argc, char** argv) {
return tensorflow::benchmark_tflite_model::Main(argc, argv);
}

View File

@ -19,6 +19,16 @@ limitations under the License.
#include "tensorflow/contrib/lite/context.h"
#include "tensorflow/contrib/lite/model.h"
// Needed to resolve unordered_set hash on older compilers.
namespace std {
template <>
struct hash<tflite::BuiltinOperator> {
size_t operator()(const tflite::BuiltinOperator& op) const {
return std::hash<int>()(op);
}
};
} // namespace std
namespace tflite {
// An OpResolver that is mutable, also used as the op in gen_op_registration.
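With that specialization in place, hashed containers keyed on the enum also compile on older standard libraries (std::hash for enum types was only guaranteed starting with C++14); a small sketch, assuming this tree's generated schema header for the BuiltinOperator enum:
#include <unordered_set>
#include "tensorflow/contrib/lite/schema/schema_generated.h"
// Without the std::hash<tflite::BuiltinOperator> specialization above, this
// declaration fails to compile on pre-C++14 standard libraries.
bool IsSupported(tflite::BuiltinOperator op) {
  static const std::unordered_set<tflite::BuiltinOperator> kSupported = {
      tflite::BuiltinOperator_ADD, tflite::BuiltinOperator_CONV_2D};
  return kSupported.count(op) != 0;
}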

View File

@ -72,6 +72,7 @@ cc_library(
"//tensorflow/core:worker_proto_cc",
"//tensorflow/core/distributed_runtime:base_rendezvous_mgr",
"//tensorflow/core/distributed_runtime:session_mgr",
"//tensorflow/core/distributed_runtime:tensor_coding",
"//tensorflow/core/distributed_runtime:worker_env",
"//third_party/mpi",
],

View File

@ -116,7 +116,7 @@ def deprecated_flipped_sparse_softmax_cross_entropy_with_logits(logits,
Raises:
ValueError: If logits are scalars (need to have rank >= 1) or if the rank
of the labels is not equal to the rank of the labels minus one.
of the labels is not equal to the rank of the logits minus one.
"""
return nn.sparse_softmax_cross_entropy_with_logits(
labels=labels, logits=logits, name=name)

View File

@ -34,7 +34,7 @@ def _rank_resample(weights, biases, inputs, sampled_values, num_resampled,
log(sum_j exp((w_i * x_j + b_i) / resampling_temperature))
where w_i, b_i are the weight and bias of the i-th class, repsectively,
where w_i, b_i are the weight and bias of the i-th class, respectively,
and j ranges over the rows of `inputs`. For efficiency, we rearrange the
computation to

View File

@ -114,7 +114,6 @@ class CoupledInputForgetGateLSTMCell(rnn_cell_impl.RNNCell):
The class uses optional peep-hole connections, and an optional projection
layer.
Layer normalization implementation is based on:
https://arxiv.org/abs/1607.06450.

View File

@ -441,7 +441,8 @@ module. Consider the simple case where we want to train the VGG network:
```python
import tensorflow as tf
vgg = tf.contrib.slim.nets.vgg
import tensorflow.contrib.slim.nets as nets
vgg = nets.vgg
# Load the images and labels.
images, labels = ...
@ -559,9 +560,10 @@ examine the following sample of training the VGG network:
```python
import tensorflow as tf
import tensorflow.contrib.slim.nets as nets
slim = tf.contrib.slim
vgg = tf.contrib.slim.nets.vgg
vgg = nets.vgg
...
@ -809,9 +811,10 @@ Putting it all together:
```python
import tensorflow as tf
import tensorflow.contrib.slim.nets as nets
slim = tf.contrib.slim
vgg = tf.contrib.slim.nets.vgg
vgg = nets.vgg
# Load the data

View File

@ -34,7 +34,7 @@ the metrics and finally call the `evaluation` method:
"mse": slim.metrics.mean_squared_error(predictions, labels),
})
inital_op = tf.group(
initial_op = tf.group(
tf.global_variables_initializer(),
tf.local_variables_initializer())
@ -42,7 +42,7 @@ the metrics and finally call the `evaluation` method:
metric_values = slim.evaluation(
sess,
num_evals=1,
inital_op=initial_op,
initial_op=initial_op,
eval_op=names_to_updates.values(),
final_op=name_to_values.values())

View File

@ -25,7 +25,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
":summary_test_internal",
":summary_test_util",
"//tensorflow/python:array_ops",
"//tensorflow/python:errors",
@ -46,7 +45,6 @@ py_test(
srcs_version = "PY2AND3",
deps = [
":summary_ops",
":summary_test_internal",
":summary_test_util",
"//tensorflow/core:protos_all_py",
"//tensorflow/python:array_ops",
@ -119,15 +117,3 @@ py_library(
"//tensorflow/python:platform",
],
)
py_library(
name = "summary_test_internal",
testonly = 1,
srcs = ["summary_test_internal.py"],
srcs_version = "PY2AND3",
visibility = ["//visibility:private"],
deps = [
"//tensorflow/python:lib",
"//tensorflow/python:platform",
],
)

View File

@ -21,7 +21,6 @@ import tempfile
import six
from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_internal
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
@ -33,10 +32,10 @@ from tensorflow.python.ops import control_flow_ops
from tensorflow.python.platform import test
from tensorflow.python.training import training_util
get_all = summary_test_internal.get_all
get_all = summary_test_util.get_all
class DbTest(summary_test_internal.SummaryDbTest):
class DbTest(summary_test_util.SummaryDbTest):
def testGraphPassedToGraph_isForbiddenForThineOwnSafety(self):
with self.assertRaises(TypeError):

View File

@ -21,7 +21,6 @@ import tempfile
import six
from tensorflow.contrib.summary import summary_ops
from tensorflow.contrib.summary import summary_test_internal
from tensorflow.contrib.summary import summary_test_util
from tensorflow.core.framework import graph_pb2
from tensorflow.core.framework import node_def_pb2
@ -35,8 +34,8 @@ from tensorflow.python.ops import state_ops
from tensorflow.python.platform import gfile
from tensorflow.python.training import training_util
get_all = summary_test_internal.get_all
get_one = summary_test_internal.get_one
get_all = summary_test_util.get_all
get_one = summary_test_util.get_one
class TargetTest(test_util.TensorFlowTestCase):
@ -137,7 +136,7 @@ class TargetTest(test_util.TensorFlowTestCase):
self.assertEqual(3, get_total())
class DbTest(summary_test_internal.SummaryDbTest):
class DbTest(summary_test_util.SummaryDbTest):
def testIntegerSummaries(self):
step = training_util.create_global_step()

View File

@ -19,13 +19,38 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import functools
import os
import sqlite3
from tensorflow.contrib.summary import summary_ops
from tensorflow.core.util import event_pb2
from tensorflow.python.framework import test_util
from tensorflow.python.lib.io import tf_record
from tensorflow.python.platform import gfile
class SummaryDbTest(test_util.TensorFlowTestCase):
"""Helper for summary database testing."""
def setUp(self):
super(SummaryDbTest, self).setUp()
self.db_path = os.path.join(self.get_temp_dir(), 'DbTest.sqlite')
if os.path.exists(self.db_path):
os.unlink(self.db_path)
self.db = sqlite3.connect(self.db_path)
self.create_summary_db_writer = functools.partial(
summary_ops.create_summary_db_writer,
db_uri=self.db_path,
experiment_name='experiment',
run_name='run',
user_name='user')
def tearDown(self):
self.db.close()
super(SummaryDbTest, self).tearDown()
def events_from_file(filepath):
"""Returns all events in a single event file.
@ -58,5 +83,17 @@ def events_from_logdir(logdir):
"""
assert gfile.Exists(logdir)
files = gfile.ListDirectory(logdir)
assert len(files) == 1, "Found not exactly one file in logdir: %s" % files
assert len(files) == 1, 'Found not exactly one file in logdir: %s' % files
return events_from_file(os.path.join(logdir, files[0]))
def get_one(db, q, *p):
return db.execute(q, p).fetchone()[0]
def get_all(db, q, *p):
return unroll(db.execute(q, p).fetchall())
def unroll(list_of_tuples):
return sum(list_of_tuples, ())

View File

@ -80,7 +80,7 @@ class DecisionsToDataThenNNTest(test_util.TensorFlowTestCase):
isinstance(self.params.num_trees, tensor_forest.ForestHParams))
with variable_scope.variable_scope(
"DecisionsToDataThenNNTest_testContructionPollution"):
"DecisionsToDataThenNNTest_testConstructionPollution"):
graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
self.params)
@ -95,7 +95,7 @@ class DecisionsToDataThenNNTest(test_util.TensorFlowTestCase):
for _ in range(100)])
with variable_scope.variable_scope(
"DecisionsToDataThenNNTest_testInferenceContruction"):
"DecisionsToDataThenNNTest_testInferenceConstruction"):
graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
self.params)
graph = graph_builder.inference_graph(data, None)
@ -111,7 +111,7 @@ class DecisionsToDataThenNNTest(test_util.TensorFlowTestCase):
labels = [1 for _ in range(100)]
with variable_scope.variable_scope(
"DecisionsToDataThenNNTest_testTrainingContruction"):
"DecisionsToDataThenNNTest_testTrainingConstruction"):
graph_builder = decisions_to_data_then_nn.DecisionsToDataThenNN(
self.params)
graph = graph_builder.training_graph(data, labels, None)

View File

@ -455,6 +455,7 @@ tf_cuda_library(
"util/mirror_pad_mode.h",
"util/padding.h",
"util/port.h",
"util/ptr_util.h",
"util/reffed_status_callback.h",
"util/saved_tensor_slice_util.h",
"util/sparse/group_iterator.h",
@ -493,6 +494,11 @@ cc_library(
],
)
cc_library(
name = "ptr_util",
hdrs = ["util/ptr_util.h"],
)
cc_library(
name = "reader_base",
srcs = ["framework/reader_base.cc"],

View File

@ -455,7 +455,6 @@ class Graph {
// the corresponding NodeDef to reflect the change.
// REQUIRES: The control edge must exist.
void RemoveControlEdge(const Edge* e);
// Updates the input to a node. The existing edge to `dst` is removed and an
// edge from `new_src` to `dst` is created. The NodeDef associated with `dst`
// is also updated.

View File

@ -1068,7 +1068,7 @@ Status ArithmeticOptimizer::SimplifyArithmeticOps() {
if (simplified_node != nullptr) {
nodes_to_simplify.PushBack(simplified_node);
}
// When `node` is simplifed to another node rather than in-place, the
// When `node` is simplified to another node rather than in-place, the
// consumers of `node` are already redirected to `simplified_tensor`.
// Re-push the consumers into `nodes_to_simplify` for further
// optimizations.

View File

@ -2583,8 +2583,13 @@ tf_kernel_library(
tf_kernel_library(
name = "batch_matmul_op",
srcs = [] + if_mkl([
"mkl_batch_matmul_op.cc",
]),
prefix = "batch_matmul_op",
deps = MATH_DEPS,
deps = MATH_DEPS + if_mkl([
"//third_party/mkl:intel_binary_blob",
]),
)
tf_kernel_library(
@ -6325,11 +6330,11 @@ cc_library(
srcs = ["summary_interface.cc"],
hdrs = ["summary_interface.h"],
deps = [
"//tensorflow/compiler/xla:util",
"//tensorflow/core:framework",
"//tensorflow/core:lib",
"//tensorflow/core:lib_internal",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:ptr_util",
],
)

View File

@ -17,8 +17,10 @@ limitations under the License.
namespace tensorflow {
#if !defined(INTEL_MKL)
TF_CALL_complex64(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_CPU);
#endif
#if GOOGLE_CUDA
TF_CALL_complex64(REGISTER_BATCH_MATMUL_GPU);

View File

@ -17,8 +17,10 @@ limitations under the License.
namespace tensorflow {
#if !defined(INTEL_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_double(REGISTER_BATCH_MATMUL_CPU);
#endif
TF_CALL_half(REGISTER_BATCH_MATMUL_CPU);
TF_CALL_int32(REGISTER_BATCH_MATMUL_CPU);
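
These guards pair with the MKL kernel registrations added elsewhere in this commit: with INTEL_MKL defined, BatchMatMulMkl claims float, double, complex64, and complex128, so the default CPU kernel must not register them as well, or the registrations would collide. A condensed sketch of the pattern, assuming the surrounding REGISTER_* macros:

#if defined(INTEL_MKL)
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);  // MKL-backed kernel owns float
#else
TF_CALL_float(REGISTER_BATCH_MATMUL_CPU);  // default Eigen-backed kernel
#endif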

View File

@ -4,7 +4,7 @@ Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,

View File

@ -34,8 +34,10 @@ class DecodeBmpOp : public OpKernel {
explicit DecodeBmpOp(OpKernelConstruction* context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("channels", &channels_));
OP_REQUIRES(
context, channels_ == 0 || channels_ == 3 || channels_ == 4,
errors::InvalidArgument("channels must be 0, 3 or 4, got ", channels_));
context,
channels_ == 0 || channels_ == 1 || channels_ == 3 || channels_ == 4,
errors::InvalidArgument("channels must be 0, 1, 3 or 4, got ",
channels_));
}
void Compute(OpKernelContext* context) override {
@ -66,11 +68,11 @@ class DecodeBmpOp : public OpKernel {
channels_ = bpp / 8;
}
// Current implementation only supports 3 or 4 channel
// Current implementation only supports 1, 3 or 4 channel
// bitmaps.
OP_REQUIRES(context, (channels_ == 3 || channels_ == 4),
OP_REQUIRES(context, (channels_ == 1 || channels_ == 3 || channels_ == 4),
errors::InvalidArgument(
"Number of channels must be 3 or 4, was ", channels_));
"Number of channels must be 1, 3 or 4, was ", channels_));
// if height is negative, data layout is top down
// otherwise, it's bottom up
@ -117,6 +119,9 @@ uint8* DecodeBmpOp::Decode(const uint8* input, uint8* const output,
dst_pos = (i * width + j) * channels;
switch (channels) {
case 1:
output[dst_pos] = input[src_pos];
break;
case 3:
// BGR -> RGB
output[dst_pos] = input[src_pos + 2];
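
To make the row-order comment above concrete, a hypothetical helper (not in the kernel itself) mapping an output row index to its source row:

// BMP rows are stored bottom-up unless the header height is negative,
// which signals top-down layout.
static int SourceRow(int row, int abs_height, bool top_down) {
  return top_down ? row : abs_height - 1 - row;
}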

View File

@ -16,6 +16,7 @@ limitations under the License.
#include <functional>
#include <memory>
#include "tensorflow/core/common_runtime/kernel_benchmark_testlib.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/fake_input.h"
#include "tensorflow/core/framework/node_def_builder.h"
@ -23,10 +24,14 @@ limitations under the License.
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/framework/types.pb.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/graph/testlib.h"
#include "tensorflow/core/kernels/ops_testutil.h"
#include "tensorflow/core/kernels/ops_util.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/lib/random/simple_philox.h"
#include "tensorflow/core/platform/test.h"
#include "tensorflow/core/platform/test_benchmark.h"
namespace tensorflow {
namespace {
@ -153,5 +158,58 @@ TEST_F(DynamicPartitionOpTest, Error_IndexOutOfRange) {
<< s;
}
Node* DynamicPartitionNode(Graph* g, Node* in0, Node* in1, int num_partitions) {
Node* ret;
TF_CHECK_OK(NodeBuilder(g->NewName("n"), "DynamicPartition")
.Input(in0)
.Input(in1)
.Attr("num_partitions", num_partitions)
.Finalize(g, &ret));
return ret;
}
template <typename T>
static Graph* DynamicPartition(int num_partitions, int dim) {
Graph* g = new Graph(OpRegistry::Global());
// Always use a 128MB buffer.
const int kRows = ((128 << 20) / sizeof(T)) / dim;
Tensor data(DataTypeToEnum<T>::value, TensorShape({kRows, dim}));
data.flat<T>().setRandom();
random::PhiloxRandom philox(301, 17);
random::SimplePhilox rnd(&philox);
Tensor partitions(DT_INT32, TensorShape({kRows}));
for (int i = 0; i < kRows; i++) {
partitions.flat<int32>()(i) = rnd.Uniform(num_partitions);
}
DynamicPartitionNode(g, test::graph::Constant(g, data),
test::graph::Constant(g, partitions), num_partitions);
return g;
}
#define BM_DYNAMIC_PARTITION(DEVICE, T, num) \
static void BM_##DEVICE##_dynpart_##T##_##num(int iters, int dim) { \
const int64 items = ((128 << 20) / sizeof(T)); \
const int64 tot = static_cast<int64>(iters) * items; \
testing::ItemsProcessed(tot); \
testing::UseRealTime(); \
test::Benchmark(#DEVICE, DynamicPartition<T>(num, dim)).Run(iters); \
} \
BENCHMARK(BM_##DEVICE##_dynpart_##T##_##num)->Arg(1)->Arg(256)
BM_DYNAMIC_PARTITION(cpu, float, 2);
BM_DYNAMIC_PARTITION(cpu, float, 100);
BM_DYNAMIC_PARTITION(cpu, double, 2);
BM_DYNAMIC_PARTITION(cpu, double, 100);
BM_DYNAMIC_PARTITION(cpu, complex64, 2);
BM_DYNAMIC_PARTITION(cpu, complex64, 100);
BM_DYNAMIC_PARTITION(gpu, float, 2);
BM_DYNAMIC_PARTITION(gpu, float, 100);
BM_DYNAMIC_PARTITION(gpu, double, 2);
BM_DYNAMIC_PARTITION(gpu, double, 100);
BM_DYNAMIC_PARTITION(gpu, complex64, 2);
BM_DYNAMIC_PARTITION(gpu, complex64, 100);
} // namespace
} // namespace tensorflow

View File

@ -0,0 +1,238 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// See docs in ../ops/math_ops.cc.
// This file uses MKL CBLAS batched xGEMM for acceleration of TF Batch
// Matrix-Matrix Multiplication (MatMul) operations.
// We currently register this kernel only for MKL-supported data
// types (float, double, complex64, complex128). The macro INTEL_MKL is
// defined by the build system only when MKL is chosen as an option at
// configure stage; when it is undefined at build time, this file becomes
// an empty compilation unit.
#define EIGEN_USE_THREADS
#if defined(INTEL_MKL)
#include <vector>
#include "mkl_cblas.h"
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
#include "tensorflow/core/framework/numeric_types.h"
#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
#include "tensorflow/core/framework/type_traits.h"
#include "tensorflow/core/framework/types.h"
#include "tensorflow/core/kernels/fill_functor.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/types.h"
#define MKL_Complex8 tensorflow::complex64
#define MKL_Complex16 tensorflow::complex128
namespace tensorflow {
typedef Eigen::ThreadPoolDevice CPUDevice;
template <typename Device, typename Scalar>
class BatchMatMulMkl : public OpKernel {
public:
explicit BatchMatMulMkl(OpKernelConstruction *context) : OpKernel(context) {
OP_REQUIRES_OK(context, context->GetAttr("adj_x", &adj_x_));
OP_REQUIRES_OK(context, context->GetAttr("adj_y", &adj_y_));
}
virtual ~BatchMatMulMkl() {}
void Compute(OpKernelContext *ctx) override {
const Tensor &lhs = ctx->input(0);
const Tensor &rhs = ctx->input(1);
OP_REQUIRES(ctx, lhs.dims() == rhs.dims(),
errors::InvalidArgument("lhs and rhs has different ndims: ",
lhs.shape().DebugString(), " vs. ",
rhs.shape().DebugString()));
const int ndims = lhs.dims();
OP_REQUIRES(
ctx, ndims >= 2,
errors::InvalidArgument("lhs and rhs ndims must be >= 2: ", ndims));
TensorShape out_shape;
for (int i = 0; i < ndims - 2; ++i) {
OP_REQUIRES(ctx, lhs.dim_size(i) == rhs.dim_size(i),
errors::InvalidArgument(
"lhs.dim(", i, ") and rhs.dim(", i,
") must be the same: ", lhs.shape().DebugString(), " vs ",
rhs.shape().DebugString()));
out_shape.AddDim(lhs.dim_size(i));
}
auto batch_size = (ndims == 2) ? 1 : out_shape.num_elements();
auto lhs_rows = lhs.dim_size(ndims - 2);
auto lhs_cols = lhs.dim_size(ndims - 1);
auto rhs_rows = rhs.dim_size(ndims - 2);
auto rhs_cols = rhs.dim_size(ndims - 1);
if (adj_x_) std::swap(lhs_rows, lhs_cols);
if (adj_y_) std::swap(rhs_rows, rhs_cols);
OP_REQUIRES(ctx, lhs_cols == rhs_rows,
errors::InvalidArgument(
"lhs mismatch rhs shape: ", lhs_cols, " vs. ", rhs_rows,
": ", lhs.shape().DebugString(), " ",
rhs.shape().DebugString(), " ", adj_x_, " ", adj_y_));
out_shape.AddDim(lhs_rows);
out_shape.AddDim(rhs_cols);
Tensor *out = nullptr;
OP_REQUIRES_OK(ctx, ctx->allocate_output(0, out_shape, &out));
if (out->NumElements() == 0) {
return;
}
if (lhs.NumElements() == 0 || rhs.NumElements() == 0) {
functor::SetZeroFunctor<Device, Scalar> f;
f(ctx->eigen_device<Device>(), out->flat<Scalar>());
return;
}
auto rhs_reshaped = rhs.template flat_inner_dims<Scalar, 3>();
auto lhs_reshaped = lhs.template flat_inner_dims<Scalar, 3>();
auto out_reshaped = out->template flat_inner_dims<Scalar, 3>();
const uint64 M = lhs_reshaped.dimension(adj_x_ ? 2 : 1);
const uint64 K = lhs_reshaped.dimension(adj_x_ ? 1 : 2);
const uint64 N = rhs_reshaped.dimension(adj_y_ ? 1 : 2);
std::vector<MKL_INT> m_array(batch_size, M);
std::vector<MKL_INT> n_array(batch_size, N);
std::vector<MKL_INT> k_array(batch_size, K);
std::vector<MKL_INT> lda_array(batch_size, adj_x_ ? M : K);
std::vector<MKL_INT> ldb_array(batch_size, adj_y_ ? K : N);
std::vector<MKL_INT> ldc_array(batch_size, N);
std::vector<MKL_INT> group_size(1, batch_size);
std::vector<const Scalar *> a_array;
std::vector<const Scalar *> b_array;
std::vector<Scalar *> c_array;
a_array.reserve(batch_size);
b_array.reserve(batch_size);
c_array.reserve(batch_size);
for (int64 i = 0; i < batch_size; i++) {
a_array.push_back(&lhs_reshaped(i, 0, 0));
b_array.push_back(&rhs_reshaped(i, 0, 0));
c_array.push_back(&out_reshaped(i, 0, 0));
}
MklCblasGemmBatch(CblasRowMajor, adj_x_, adj_y_, &m_array[0], &n_array[0],
&k_array[0], &a_array[0], &lda_array[0], &b_array[0],
&ldb_array[0], &c_array[0], &ldc_array[0], 1,
&group_size[0]);
}
private:
bool adj_x_;
bool adj_y_;
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array,
const float **A_Array, const MKL_INT *lda_Array,
const float **B_Array, const MKL_INT *ldb_Array,
float **C_Array, const MKL_INT *ldc_Array,
const MKL_INT group_count, const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_Array(
group_size[0], TransA ? CblasTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_Array(
group_size[0], TransB ? CblasTrans : CblasNoTrans);
std::vector<float> alpha_Array(group_size[0], 1.0);
std::vector<float> beta_Array(group_size[0], 0.0);
cblas_sgemm_batch(Layout, &TransA_Array[0], &TransB_Array[0], M_Array,
N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
group_count, group_size);
}
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array,
const double **A_Array, const MKL_INT *lda_Array,
const double **B_Array, const MKL_INT *ldb_Array,
double **C_Array, const MKL_INT *ldc_Array,
const MKL_INT group_count, const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasTrans : CblasNoTrans);
std::vector<double> alpha_Array(group_size[0], 1.0);
std::vector<double> beta_Array(group_size[0], 0.0);
cblas_dgemm_batch(Layout, &TransA_array[0], &TransB_array[0], M_Array,
N_Array, K_Array, &alpha_Array[0], A_Array, lda_Array,
B_Array, ldb_Array, &beta_Array[0], C_Array, ldc_Array,
group_count, group_size);
}
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array,
const MKL_Complex8 **A_Array, const MKL_INT *lda_Array,
const MKL_Complex8 **B_Array, const MKL_INT *ldb_Array,
MKL_Complex8 **C_Array, const MKL_INT *ldc_Array,
const MKL_INT group_count, const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
std::vector<MKL_Complex8> alpha_Array(group_size[0], {1.0f, 0.0f});
std::vector<MKL_Complex8> beta_Array(group_size[0], {0.0f, 0.0f});
cblas_cgemm_batch(
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
static_cast<const void *>(&alpha_Array[0]),
reinterpret_cast<const void **>(A_Array), lda_Array,
reinterpret_cast<const void **>(B_Array), ldb_Array,
static_cast<const void *>(&beta_Array[0]),
reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
}
void MklCblasGemmBatch(const CBLAS_LAYOUT Layout, const bool TransA,
const bool TransB, const MKL_INT *M_Array,
const MKL_INT *N_Array, const MKL_INT *K_Array,
const MKL_Complex16 **A_Array,
const MKL_INT *lda_Array,
const MKL_Complex16 **B_Array,
const MKL_INT *ldb_Array, MKL_Complex16 **C_Array,
const MKL_INT *ldc_Array, const MKL_INT group_count,
const MKL_INT *group_size) {
std::vector<CBLAS_TRANSPOSE> TransA_array(
group_size[0], TransA ? CblasConjTrans : CblasNoTrans);
std::vector<CBLAS_TRANSPOSE> TransB_array(
group_size[0], TransB ? CblasConjTrans : CblasNoTrans);
std::vector<MKL_Complex16> alpha_Array(group_size[0], {1.0, 0.0});
std::vector<MKL_Complex16> beta_Array(group_size[0], {0.0, 0.0});
cblas_zgemm_batch(
Layout, &TransA_array[0], &TransB_array[0], M_Array, N_Array, K_Array,
static_cast<const void *>(&alpha_Array[0]),
reinterpret_cast<const void **>(A_Array), lda_Array,
reinterpret_cast<const void **>(B_Array), ldb_Array,
static_cast<const void *>(&beta_Array[0]),
reinterpret_cast<void **>(C_Array), ldc_Array, group_count, group_size);
}
};
#define REGISTER_BATCH_MATMUL_MKL(TYPE) \
REGISTER_KERNEL_BUILDER( \
Name("BatchMatMul").Device(DEVICE_CPU).TypeConstraint<TYPE>("T"), \
BatchMatMulMkl<CPUDevice, TYPE>)
TF_CALL_float(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_double(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex64(REGISTER_BATCH_MATMUL_MKL);
TF_CALL_complex128(REGISTER_BATCH_MATMUL_MKL);
} // end namespace tensorflow
#endif  // defined(INTEL_MKL)
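
A note on the grouped API used above: MKL's cblas_?gemm_batch takes per-group parameter arrays of length group_count, while the A/B/C pointer arrays cover every matrix across all groups; the kernel therefore passes group_count = 1 with group_size[0] = batch_size so that every matmul in the batch shares one set of dimensions. A standalone sketch (assumes an MKL build) with one group of two 2x2 products:

#include "mkl_cblas.h"

void ExampleSgemmBatch() {
  float a0[4] = {1, 2, 3, 4}, a1[4] = {1, 0, 0, 1};
  float b0[4] = {5, 6, 7, 8}, b1[4] = {1, 1, 1, 1};
  float c0[4] = {0}, c1[4] = {0};
  const float *a_array[2] = {a0, a1}, *b_array[2] = {b0, b1};
  float *c_array[2] = {c0, c1};
  // One group: a single entry describes both products.
  CBLAS_TRANSPOSE trans[1] = {CblasNoTrans};
  MKL_INT m[1] = {2}, n[1] = {2}, k[1] = {2};
  MKL_INT lda[1] = {2}, ldb[1] = {2}, ldc[1] = {2};
  float alpha[1] = {1.0f}, beta[1] = {0.0f};
  MKL_INT group_size[1] = {2};  // two matmuls in the group
  cblas_sgemm_batch(CblasRowMajor, trans, trans, m, n, k, alpha, a_array,
                    lda, b_array, ldb, beta, c_array, ldc,
                    /*group_count=*/1, group_size);
}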

View File

@ -37,6 +37,8 @@ class PrefetchDatasetOp : public UnaryDatasetOpKernel {
int64 buffer_size;
OP_REQUIRES_OK(
ctx, ParseScalarArgument<int64>(ctx, "buffer_size", &buffer_size));
OP_REQUIRES(ctx, buffer_size > 0,
errors::InvalidArgument("buffer_size must be > 0"));
*output = new Dataset(ctx, input, buffer_size);
}

View File

@ -16,7 +16,6 @@ limitations under the License.
#include <utility>
#include "tensorflow/compiler/xla/ptr_util.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/op_kernel.h"
#include "tensorflow/core/framework/resource_mgr.h"
@ -28,6 +27,7 @@ limitations under the License.
#include "tensorflow/core/lib/png/png_io.h"
#include "tensorflow/core/lib/wav/wav_io.h"
#include "tensorflow/core/util/events_writer.h"
#include "tensorflow/core/util/ptr_util.h"
namespace tensorflow {
namespace {
@ -229,7 +229,7 @@ class SummaryWriterImpl : public SummaryWriterInterface {
}
mutex_lock ml(mu_);
events_writer_ =
xla::MakeUnique<EventsWriter>(io::JoinPath(logdir, "events"));
tensorflow::MakeUnique<EventsWriter>(io::JoinPath(logdir, "events"));
if (!events_writer_->InitWithSuffix(filename_suffix)) {
return errors::Unknown("Could not initialize events writer.");
}

Some files were not shown because too many files have changed in this diff.