Branch 152232810 (#8988)
* Improve py_func error handling. Automatically translate some Python errors into corresponding TF errors at runtime. Change: 152156821
* Update interaction with libpng so that we use the public API instead of knowledge of the internal libpng data structures. Change: 152167754
* TensorBoard plugins now contain their own name/route prefix. Change: 152167807
* Passes trainable flag to separable_conv2d biases. Change: 152170239
* Saving resource variables with a caching device. Change: 152171539
* Drop loss from estimator_spec.eval_metric_ops, as required by core Estimator. Change: 152179924
* sample_stats.percentile DOCFIX. Change: 152182295
* Added a memory optimizer to grappler. Change: 152184170
* Change the default behavior of the tf runs selector:
  - If there are fewer than 41 runs, enable them all by default.
  - If there are 41 runs or more, disable them all by default.
  This is in response to user complaints that having it enable only the first ten runs by default was confusing, because it was not obvious to users that some runs had been disabled. However, it still solves the initial user complaint that having very many runs simultaneously enabled would lag the UI. I also changed the "toggle all runs" button to try to turn everything off before turning everything on. Also, I improved the logic for detecting when the runs selection is back in the default state, so that we can avoid generating long URI strings wherever possible. A sketch of the new default rule follows this list. Change: 152188948
* Autogenerated Change: Change TensorBoard TAG to 52. Change: 152189000
* Remove a warning that only happens with config cuda. Change: 152189205
* Make resource variable shared name consistent with non-resource variables. Remove the colocation constraint from the resource variable cached value with the variable itself. Change: 152192203
* Add a way to specify the optimization order; refactor and add constant folding to the meta optimizer. Change: 152193646
* Backport fixes and improvements from external Keras. Change: 152198296
* Merge changes from github. Change: 152200430
* Go: Update generated wrapper functions for TensorFlow ops. Change: 152200754
* Update ops-related pbtxt files. Change: 152203174
* Make ImportGraphDef() work with functions. In addition to modifying graph_constructor.cc, this patch adds some other functionality to enable importing functions:
  - the ability to add FunctionDefLibraries to Graphs and FunctionLibraryDefinitions (in addition to existing functions), and
  - a FunctionDefsEqual() utility function.
  Change: 152205258
* Expand contrib test to more than just test targets. Change: 152206822
* Preserve graph version during optimization. Change: 152213262
* Exclude enter and exit nodes from shape refiner's constant folding. Change: 152213637
* Allow reshape_mover and algebraic_simplifier to make multiple mutations, by avoiding the short-circuit std::any_of. Change: 152232810
* fixing workspace.bzl
* workspace.bzl further fixes
* fixing tensorflow.bzl merge conflicts
* fixing typo in dnn.h
* fixing bad merge for dnn.h
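The runs-selector change above is a simple threshold on the run count. A minimal illustrative sketch in Python (the real logic lives in TensorBoard's frontend; the helper name and constant are hypothetical, with 41 taken from the commit message):

```python
DEFAULT_RUN_LIMIT = 41  # threshold described in the commit message

def default_enabled_runs(runs):
    """Enable every run by default iff there are fewer than 41 runs."""
    if len(runs) < DEFAULT_RUN_LIMIT:
        return set(runs)   # fewer than 41 runs: enable them all
    return set()           # 41 runs or more: start with everything disabled

print(default_enabled_runs(["run_%d" % i for i in range(10)]))  # all enabled
print(default_enabled_runs(["run_%d" % i for i in range(50)]))  # none enabled
```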
parent b93a88f86d
commit e69f71759a
Changed files:

tensorflow
  compiler
  contrib
    distributions/python/ops
    keras/python/keras
      __init__.py
      activations.py
      applications
      backend.py
      engine
      initializers.py
      layers
        convolutional.py
        convolutional_recurrent.py
        core.py
        local.py
        merge.py
        normalization.py
        pooling.py
        recurrent.py
        wrappers.py
      metrics.py
      models.py
      preprocessing
      utils
      wrappers
    layers/python/layers
    learn/python/learn
      dataframe/queues
      estimators
      learn_io
    opt/python/training
    seq2seq/python/ops
  core
    common_runtime
    framework
    graph
      graph.cc
      graph.h
      graph_constructor.cc
      graph_constructor.h
      graph_constructor_test.cc
      graph_test.cc
      mkl_layout_pass.cc
      mkl_layout_pass_test.cc
      mkl_tfconversion_pass.cc
      mkl_tfconversion_pass_test.cc
    grappler/optimizers
      BUILD
      constant_folding.cc
      memory_optimizer.cc
      memory_optimizer.h
      memory_optimizer_test.cc
      meta_optimizer.cc
      meta_optimizer.h
    kernels
      conv_grad_filter_ops.cc
      conv_grad_input_ops.cc
      cudnn_pooling_gpu.cc
      maxpooling_op.cc
      maxpooling_op_gpu.cu.cc
      maxpooling_op_gpu.h
      mkl_avgpooling_op.cc
      mkl_conv_grad_bias_ops.cc
      mkl_conv_grad_filter_ops.cc
      mkl_conv_grad_input_ops.cc
      mkl_conv_ops.cc
      mkl_maxpooling_op.cc
      mkl_pooling_ops_common.cc
      mkl_pooling_ops_common.h
      mkl_relu_op.cc
      pooling_ops_3d.cc
      pooling_ops_3d_gpu.cu.cc
      pooling_ops_common.cc
      xsmm_conv2d.cc
    lib
    ops
    platform
    protobuf
    util
  go
  python
    estimator
    framework
    kernel_tests
@@ -116,13 +116,14 @@ class NAryOpsTest(XLATestCase):
                     np.array([1, 1], dtype=np.int32)],
                    expected=np.array([[], []], dtype=np.float32))

-    if (np.int64 in self.int_types):
-      self._testNAry(lambda x: array_ops.strided_slice(*x),
-                     [np.array([[], [], []], dtype=np.float32),
-                      np.array([1, 0], dtype=np.int64),
-                      np.array([3, 0], dtype=np.int64),
-                      np.array([1, 1], dtype=np.int64)],
-                     expected=np.array([[], []], dtype=np.float32))
+    if np.int64 in self.int_types:
+      self._testNAry(
+          lambda x: array_ops.strided_slice(*x), [
+              np.array([[], [], []], dtype=np.float32), np.array(
+                  [1, 0], dtype=np.int64), np.array([3, 0], dtype=np.int64),
+              np.array([1, 1], dtype=np.int64)
+          ],
+          expected=np.array([[], []], dtype=np.float32))

     self._testNAry(lambda x: array_ops.strided_slice(*x),
                    [np.array([[1, 2, 3], [4, 5, 6], [7, 8, 9]],
@@ -1348,13 +1348,14 @@ Status AlgebraicSimplifierVisitor::HandleMinimum(HloInstruction* minimum,
 StatusOr<bool> AlgebraicSimplifier::Run(HloModule* module) {
   XLA_VLOG_LINES(2,
                  "AlgebraicSimplifier::Run(), before:\n" + module->ToString());
-  bool changed =
-      std::any_of(module->computations().begin(), module->computations().end(),
-                  [=](const std::unique_ptr<HloComputation>& computation) {
-                    return AlgebraicSimplifierVisitor::Run(
-                        computation.get(), is_layout_sensitive_,
-                        valid_bitcast_callback_, enable_dot_simplification_);
-                  });
+  bool changed = false;
+  for (auto& comp : module->computations()) {
+    if (AlgebraicSimplifierVisitor::Run(comp.get(), is_layout_sensitive_,
+                                        valid_bitcast_callback_,
+                                        enable_dot_simplification_)) {
+      changed = true;
+    }
+  }
   XLA_VLOG_LINES(2,
                  "AlgebraicSimplifier::Run(), after:\n" + module->ToString());
   return changed;
@@ -234,17 +234,15 @@ bool TrySinkReshapeOrTranspose(HloComputation* computation,
 }  // namespace

 StatusOr<bool> ReshapeMover::Run(HloModule* module) {
-  return std::any_of(
-      module->computations().begin(), module->computations().end(),
-      [](const std::unique_ptr<HloComputation>& computation) {
-        std::list<HloInstruction*> postorder =
-            computation->MakeInstructionPostOrder();
-        return std::any_of(postorder.begin(), postorder.end(),
-                           [&computation](HloInstruction* instruction) {
-                             return TrySinkReshapeOrTranspose(computation.get(),
-                                                              instruction);
-                           });
-      });
+  bool changed = false;
+  for (const auto& comp : module->computations()) {
+    for (HloInstruction* instruction : comp->MakeInstructionPostOrder()) {
+      if (TrySinkReshapeOrTranspose(comp.get(), instruction)) {
+        changed = true;
+      }
+    }
+  }
+  return changed;
 }

 }  // namespace xla
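Why the two loop rewrites above: `std::any_of` stops at the first element whose predicate returns true, so a side-effecting pass written that way mutates at most one computation per run. Python's `any` short-circuits the same way; a tiny sketch of the pitfall and the fix:

```python
# any() stops visiting elements after the first truthy result, just like
# std::any_of -- so a mutating pass driven by it touches only one element.
visited = []

def try_mutate(x):
    visited.append(x)
    return True  # pretend every element gets mutated

any(try_mutate(x) for x in [1, 2, 3])
print(visited)  # [1] -- short-circuited after the first element

# The fix: visit everything and OR the results together.
visited.clear()
changed = False
for x in [1, 2, 3]:
    changed |= try_mutate(x)
print(visited, changed)  # [1, 2, 3] True
```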
@@ -202,5 +202,56 @@ TEST_F(ReshapeMoverTest, ScalarReshapeNotMovedAcrossSelect) {
   EXPECT_EQ(select, computation->root_instruction());
 }

+// Tree looks like this:
+//
+// add1
+// |
+// +- reshape2 - param2
+// |
+// +- reshape3 - add0
+//               |
+//               + reshape0 - param0
+//               |
+//               + reshape1 - param1
+//
+// We expect reshape{0,1} AND reshape{2,3} to be lifted.
+TEST_F(ReshapeMoverTest, MultiplePasses) {
+  auto shape1 = ShapeUtil::MakeShape(F32, {1, 8, 1, 7});
+  auto shape2 = ShapeUtil::MakeShape(F32, {8, 7, 1});
+  auto shape3 = ShapeUtil::MakeShape(F32, {8, 7});
+  HloComputation::Builder builder(TestName());
+  auto param0 = builder.AddInstruction(
+      HloInstruction::CreateParameter(0, shape1, "param0"));
+  auto param1 = builder.AddInstruction(
+      HloInstruction::CreateParameter(1, shape1, "param1"));
+  auto param2 = builder.AddInstruction(
+      HloInstruction::CreateParameter(2, shape2, "param2"));
+  auto reshape0 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape2, param0));
+  auto reshape1 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape2, param1));
+  auto add0 = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape2, HloOpcode::kAdd, reshape0, reshape1));
+  auto reshape2 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape3, param2));
+  auto reshape3 =
+      builder.AddInstruction(HloInstruction::CreateReshape(shape3, add0));
+  auto add1 = builder.AddInstruction(HloInstruction::CreateBinary(
+      shape3, HloOpcode::kAdd, reshape2, reshape3));
+
+  auto module = MakeUnique<HloModule>(TestName());
+  auto computation = module->AddEntryComputation(builder.Build());
+  EXPECT_EQ(add1, computation->root_instruction());
+  EXPECT_TRUE(ReshapeMover().Run(module.get()).ValueOrDie());
+  EXPECT_EQ(HloOpcode::kReshape, computation->root_instruction()->opcode());
+  EXPECT_EQ(HloOpcode::kAdd,
+            computation->root_instruction()->operand(0)->opcode());
+  const auto& add_params =
+      computation->root_instruction()->operand(0)->operands();
+  EXPECT_EQ(2, add_params.size());
+  EXPECT_EQ(HloOpcode::kParameter, add_params[0]->opcode());
+  EXPECT_EQ(HloOpcode::kReshape, add_params[1]->opcode());
+}
+
 }  // namespace
 }  // namespace xla
@@ -44,7 +44,7 @@ def percentile(x,
                keep_dims=False,
                validate_args=False,
                name=None):
-  """Compute the `q`-th percentile of `x` along leading (sample) dimensions.
+  """Compute the `q`-th percentile of `x`.

   Given a vector `x`, the `q`-th percentile of `x` is the value `q / 100` of the
   way from the minimum to the maximum in in a sorted copy of `x`.
@@ -58,7 +58,7 @@ def percentile(x,


   ```python
-  # Get 30th percentile with default ('linear') interpolation.
+  # Get 30th percentile with default ('nearest') interpolation.
   x = [1., 2., 3., 4.]
   percentile(x, q=30.)
   ==> 2.0
@@ -91,11 +91,10 @@ def percentile(x,
     axis: Optional `0-D` or `1-D` integer `Tensor` with constant values.
       The axis that hold independent samples over which to return the desired
       percentile.  If `None` (the default), treat every dimension as a sample
-      dimension, returning a scalar
+      dimension, returning a scalar.
     interpolation : {"lower", "higher", "nearest"}.  Default: "nearest"
       This optional parameter specifies the interpolation method to
-      use when the desired quantile lies between two data points
-      `i < j`:
+      use when the desired quantile lies between two data points `i < j`:
       * lower: `i`.
       * higher: `j`.
       * nearest: `i` or `j`, whichever is nearest.
@@ -37,4 +37,4 @@ from tensorflow.contrib.keras.python.keras import utils
 from tensorflow.contrib.keras.python.keras import wrappers


-__version__ = '2.0.0-tf'
+__version__ = '2.0.2-tf'
@@ -24,18 +24,28 @@ from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.utils.generic_utils import deserialize_keras_object


-def softmax(x):
+def softmax(x, axis=-1):
+  """Softmax activation function.
+
+  Arguments:
+      x : Tensor.
+      axis: Integer, axis along which the softmax normalization is applied.
+
+  Returns:
+      Tensor, output of softmax transformation.
+
+  Raises:
+      ValueError: In case `dim(x) == 1`.
+  """
   ndim = K.ndim(x)
   if ndim == 2:
     return K.softmax(x)
-  elif ndim == 3:
-    e = K.exp(x - K.max(x, axis=-1, keepdims=True))
-    s = K.sum(e, axis=-1, keepdims=True)
+  elif ndim > 2:
+    e = K.exp(x - K.max(x, axis=axis, keepdims=True))
+    s = K.sum(e, axis=axis, keepdims=True)
     return e / s
   else:
-    raise ValueError('Cannot apply softmax to a tensor '
-                     'that is not 2D or 3D. '
-                     'Here, ndim=' + str(ndim))
+    raise ValueError('Cannot apply softmax to a tensor that is 1D')


 def elu(x, alpha=1.0):
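A NumPy rendering of the generalized softmax above, handy for checking the new `axis` behavior on tensors of rank > 2 (a sketch mirroring the diff's math, not the Keras API itself):

```python
import numpy as np

def softmax(x, axis=-1):
    # Subtract the max for numerical stability, as the Keras code above does.
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)

x = np.random.rand(2, 3, 4)
# Each slice along the chosen axis now sums to 1.
print(np.allclose(softmax(x, axis=1).sum(axis=1), 1.0))  # True
```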
@@ -163,8 +163,8 @@ def ResNet50(include_top=True,
       specified in your Keras config file.

   Arguments:
-      include_top: whether to include the 3 fully-connected
-          layers at the top of the network.
+      include_top: whether to include the fully-connected
+          layer at the top of the network.
       weights: one of `None` (random initialization)
           or "imagenet" (pre-training on ImageNet).
       input_tensor: optional Keras tensor (i.e. output of `layers.Input()`)
@@ -22,7 +22,6 @@ from __future__ import division
 from __future__ import print_function

 from collections import defaultdict
-import errno
 import json
 import os
 import warnings
@@ -270,6 +269,7 @@ def clear_session():
   reset_uids()
   _SESSION = None
+  phase = array_ops.placeholder(dtype='bool', name='keras_learning_phase')
   _GRAPH_LEARNING_PHASES = {}
   _GRAPH_LEARNING_PHASES[ops.get_default_graph()] = phase

@@ -1257,6 +1257,34 @@ def prod(x, axis=None, keepdims=False):
   return math_ops.reduce_prod(x, reduction_indices=axis, keep_dims=keepdims)


+def cumsum(x, axis=0):
+  """Cumulative sum of the values in a tensor, alongside the specified axis.
+
+  Arguments:
+      x: A tensor or variable.
+      axis: An integer, the axis to compute the sum.
+
+  Returns:
+      A tensor of the cumulative sum of values of `x` along `axis`.
+  """
+  axis = _normalize_axis(axis, ndim(x))
+  return math_ops.cumsum(x, axis=axis)
+
+
+def cumprod(x, axis=0):
+  """Cumulative product of the values in a tensor, alongside the specified axis.
+
+  Arguments:
+      x: A tensor or variable.
+      axis: An integer, the axis to compute the product.
+
+  Returns:
+      A tensor of the cumulative product of values of `x` along `axis`.
+  """
+  axis = _normalize_axis(axis, ndim(x))
+  return math_ops.cumprod(x, axis=axis)
+
+
 def var(x, axis=None, keepdims=False):
   """Variance of a tensor, alongside the specified axis.

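The two new backend ops map directly onto `tf.cumsum`/`tf.cumprod`. In NumPy terms (a sketch of the expected semantics, not the TF code):

```python
import numpy as np

x = np.array([[1., 2., 3.],
              [4., 5., 6.]])
print(np.cumsum(x, axis=0))   # [[1. 2. 3.] [5. 7. 9.]]
print(np.cumprod(x, axis=1))  # [[  1.   2.   6.] [  4.  20. 120.]]
```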
@@ -1330,8 +1358,7 @@ def any(x, axis=None, keepdims=False):
   """
   axis = _normalize_axis(axis, ndim(x))
   x = math_ops.cast(x, dtypes_module.bool)
-  x = math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)
-  return math_ops.cast(x, dtypes_module.uint8)
+  return math_ops.reduce_any(x, reduction_indices=axis, keep_dims=keepdims)


 def all(x, axis=None, keepdims=False):
@@ -1347,8 +1374,7 @@ def all(x, axis=None, keepdims=False):
   """
   axis = _normalize_axis(axis, ndim(x))
   x = math_ops.cast(x, dtypes_module.bool)
-  x = math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)
-  return math_ops.cast(x, dtypes_module.uint8)
+  return math_ops.reduce_all(x, reduction_indices=axis, keep_dims=keepdims)


 def argmax(x, axis=-1):
@@ -1645,7 +1671,7 @@ def normalize_batch_in_training(x, gamma, beta, reduction_axes, epsilon=1e-3):
   """
   mean, var = nn.moments(
       x, reduction_axes, shift=None, name=None, keep_dims=False)
-  if sorted(reduction_axes) == range(ndim(x))[:-1]:
+  if sorted(reduction_axes) == list(range(ndim(x)))[:-1]:
     normed = nn.batch_normalization(x, mean, var, beta, gamma, epsilon)
   else:
     # need broadcasting
@@ -2324,8 +2350,8 @@ def rnn(step_function,
           (no time dimension),
           containing the initial values for the states used in
           the step function.
-      go_backwards: boolean. If True, do the iteration over
-          the time dimension in reverse order.
+      go_backwards: boolean. If True, do the iteration over the time
+          dimension in reverse order and return the reversed sequence.
       mask: binary tensor with shape `(samples, time, 1)`,
           with a zero for every element that is masked.
       constants: a list of constant values passed at each step.
@@ -2414,9 +2440,9 @@ def rnn(step_function,
         states = return_states
         successive_outputs.append(output)
         successive_states.append(states)
-        last_output = successive_outputs[-1]
-        new_states = successive_states[-1]
-        outputs = array_ops.stack(successive_outputs)
+      last_output = successive_outputs[-1]
+      new_states = successive_states[-1]
+      outputs = array_ops.stack(successive_outputs)
   else:
     for inp in input_list:
       output, states = step_function(inp, states + constants)
@@ -3534,19 +3560,19 @@ def ctc_decode(y_pred, input_length, greedy=True, beam_width=100, top_paths=1):
 # HIGH ORDER FUNCTIONS


-def map_fn(fn, elems, name=None):
+def map_fn(fn, elems, name=None, dtype=None):
   """Map the function fn over the elements elems and return the outputs.

   Arguments:
       fn: Callable that will be called upon each element in elems
       elems: tensor
       name: A string name for the map node in the graph
+      dtype: Output data type.

   Returns:
-      Tensor with first dimension equal to the elems and second depending on
-      fn
+      Tensor with dtype `dtype`.
   """
-  return functional_ops.map_fn(fn, elems, name=name)
+  return functional_ops.map_fn(fn, elems, name=name, dtype=dtype)


 def foldl(fn, elems, initializer=None, name=None):
@@ -3560,7 +3586,7 @@ def foldl(fn, elems, initializer=None, name=None):
       name: A string name for the foldl node in the graph

   Returns:
-      Same type and shape as initializer
+      Tensor with same type and shape as `initializer`.
   """
   return functional_ops.foldl(fn, elems, initializer=initializer, name=name)

@@ -3583,27 +3609,39 @@ def foldr(fn, elems, initializer=None, name=None):

 # Load Keras default configuration from config file if present.
 _keras_base_dir = os.path.expanduser('~')
-if not os.access(_keras_base_dir, os.W_OK):
-  _keras_base_dir = '/tmp'
 _keras_dir = os.path.join(_keras_base_dir, '.keras')
-if not os.path.exists(_keras_dir):
-  try:
-    os.makedirs(_keras_dir)
-  except OSError as e:
-    if e.errno == errno.EEXIST:
-      pass
-    else:
-      raise
 _config_path = os.path.expanduser(os.path.join(_keras_dir, 'keras.json'))
 if os.path.exists(_config_path):
-  _config = json.load(open(_config_path))
+  try:
+    _config = json.load(open(_config_path))
+  except json.decoder.JSONDecodeError:
+    _config = {}
   _floatx = _config.get('floatx', floatx())
   assert _floatx in {'float16', 'float32', 'float64'}
   _epsilon = _config.get('epsilon', epsilon())
   assert isinstance(_epsilon, float)
   _backend = backend()
   _image_data_format = _config.get('image_data_format', image_data_format())
   assert _image_data_format in {'channels_last', 'channels_first'}
   set_floatx(_floatx)
   set_epsilon(_epsilon)
   set_image_data_format(_image_data_format)
+
+# Save config file.
+if os.access(_keras_base_dir, os.W_OK):
+  if not os.path.exists(_keras_dir):
+    try:
+      os.makedirs(_keras_dir)
+    except OSError:
+      # Except potential race conditions
+      # in multi-threaded environments.
+      pass
+
+  if not os.path.exists(_config_path):
+    _config = {
+        'floatx': floatx(),
+        'epsilon': epsilon(),
+        'backend': 'tensorflow',
+        'image_data_format': image_data_format()
+    }
+    with open(_config_path, 'w') as f:
+      f.write(json.dumps(_config, indent=4))
@@ -295,8 +295,14 @@ class Layer(object):
     # are only applicable to input layers: do not pass these keywords
     # to non-input layers.
     allowed_kwargs = {
-        'input_shape', 'batch_input_shape', 'batch_size', 'dtype', 'name',
-        'trainable', 'weights'
+        'input_shape',
+        'batch_input_shape',
+        'batch_size',
+        'dtype',
+        'name',
+        'trainable',
+        'weights',
+        'input_dtype',  # legacy
     }
     for kwarg in kwargs:
       if kwarg not in allowed_kwargs:
@@ -320,8 +326,15 @@ class Layer(object):
         batch_size = None
       batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
       self.batch_input_shape = batch_input_shape
-    dtype = kwargs.get('dtype', K.floatx())
+
+    # Set dtype.
+    dtype = kwargs.get('dtype')
+    if dtype is None:
+      dtype = kwargs.get('input_dtype')
+    if dtype is None:
+      dtype = K.floatx()
     self.dtype = dtype
+
     if 'weights' in kwargs:
       self._initial_weights = kwargs['weights']
     else:
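The dtype resolution above is a three-step fallback: the `dtype` kwarg wins, then the legacy Keras 1 `input_dtype` kwarg, then the backend default. A standalone sketch of the same chain (`floatx` stands in for `K.floatx()`):

```python
def resolve_dtype(kwargs, floatx='float32'):
    # Mirrors the fallback chain in Layer.__init__ above:
    # explicit dtype > legacy input_dtype > backend default.
    dtype = kwargs.get('dtype')
    if dtype is None:
        dtype = kwargs.get('input_dtype')  # legacy Keras 1 keyword
    if dtype is None:
        dtype = floatx
    return dtype

print(resolve_dtype({'dtype': 'float64'}))      # float64
print(resolve_dtype({'input_dtype': 'int32'}))  # int32
print(resolve_dtype({}))                        # float32
```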
@@ -485,11 +498,12 @@ class Layer(object):
                          ': expected shape=' + str(spec.shape) +
                          ', found shape=' + str(x_shape))

-  def call(self, inputs):
+  def call(self, inputs, **kwargs):  # pylint: disable=unused-argument
     """This is where the layer's logic lives.

     Arguments:
-        inputs: input tensor, or list/tuple of input tensors.
+        inputs: Input tensor, or list/tuple of input tensors.
+        **kwargs: Additional keyword arguments.

     Returns:
         A tensor or list/tuple of tensors.
@@ -518,6 +532,8 @@ class Layer(object):
         ValueError: in case the layer is missing shape information
             for its `build` call.
     """
+    if isinstance(inputs, list):
+      inputs = inputs[:]
     with K.name_scope(self.name):
       # Handle laying building (weight creating, input spec locking).
       if not self.built:
@@ -1417,7 +1433,7 @@ class Container(Layer):
       get_weights
       set_weights
       get_config
-      get_output_shape_for
+      compute_output_shape

   # Class Methods
       from_config
@@ -2029,7 +2045,7 @@ class Container(Layer):
       for i in range(len(input_shapes)):
         layer = self.input_layers[i]
         input_shape = input_shapes[i]
-        # It's an input layer: get_output_shape_for is identity,
+        # It's an input layer: compute_output_shape is identity,
         # and there is only one node and one tensor output.
         shape_key = layer.name + '_0_0'
         layers_to_output_shapes[shape_key] = input_shape
@@ -733,11 +733,12 @@ class Model(Container):
       loss_functions = []
       for name in self.output_names:
         if name not in loss:
-          warnings.warn('Output "' + name + '" missing from loss dictionary. '
-                        'We assume this was done on purpose, '
-                        'and we will not be expecting '
-                        'any data to be passed to "' + name +
-                        '" during training.')
+          warnings.warn(
+              'Output "' + name + '" missing from loss dictionary. '
+              'We assume this was done on purpose, '
+              'and we will not be expecting '
+              'any data to be passed to "' + name + '" during training.',
+              stacklevel=2)
         loss_functions.append(losses.get(loss.get(name)))
     elif isinstance(loss, list):
       if len(loss) != len(self.outputs):
@@ -1202,7 +1203,7 @@ class Model(Container):
         if batch_index == 0:
           for batch_out in batch_outs:
             shape = (samples,) + batch_out.shape[1:]
-            outs.append(np.zeros(shape, dtype=K.floatx()))
+            outs.append(np.zeros(shape, dtype=batch_out.dtype))

         for i, batch_out in enumerate(batch_outs):
           outs[i][batch_start:batch_end] = batch_out
@@ -1718,7 +1719,7 @@ class Model(Container):
             - a tuple (inputs, targets, sample_weights).
             All arrays should contain the same number of samples.
             The generator is expected to loop over its data
-            indefinitely. An epoch finishes when `samples_per_epoch`
+            indefinitely. An epoch finishes when `steps_per_epoch`
             samples have been seen by the model.
         steps_per_epoch: Total number of steps (batches of samples)
             to yield from `generator` before declaring one epoch
@@ -1767,7 +1768,7 @@ class Model(Container):
                 f.close()

         model.fit_generator(generate_arrays_from_file('/my_file.txt'),
-                            samples_per_epoch=10000, epochs=10)
+                            steps_per_epoch=10000, epochs=10)
         ```

     Raises:
@@ -2028,7 +2029,8 @@ class Model(Container):
                         steps,
                         max_q_size=10,
                         workers=1,
-                        pickle_safe=False):
+                        pickle_safe=False,
+                        verbose=0):
     """Generates predictions for the input samples from a data generator.

     The generator should return the same kind of data as accepted by
@@ -2048,6 +2050,7 @@ class Model(Container):
             non picklable arguments to the generator
             as they can't be passed
             easily to children processes.
+        verbose: verbosity mode, 0 or 1.

     Returns:
         Numpy array(s) of predictions.
@@ -2067,6 +2070,9 @@ class Model(Container):
       enqueuer = GeneratorEnqueuer(generator, pickle_safe=pickle_safe)
       enqueuer.start(workers=workers, max_q_size=max_q_size)

+      if verbose == 1:
+        progbar = Progbar(target=steps)
+
       while steps_done < steps:
         generator_output = None
         while enqueuer.is_running():
@@ -2103,6 +2109,8 @@ class Model(Container):
           for i, out in enumerate(outs):
             all_outs[i].append(out)
         steps_done += 1
+        if verbose == 1:
+          progbar.update(steps_done)

     finally:
       if enqueuer is not None:
@@ -45,14 +45,16 @@ class Initializer(object):


 class Zeros(Initializer):
-  """Initializer that generates tensors initialized to 0."""
+  """Initializer that generates tensors initialized to 0.
+  """

   def __call__(self, shape, dtype=None):
     return K.constant(0, shape=shape, dtype=dtype)


 class Ones(Initializer):
-  """Initializer that generates tensors initialized to 1."""
+  """Initializer that generates tensors initialized to 1.
+  """

   def __call__(self, shape, dtype=None):
     return K.constant(1, shape=shape, dtype=dtype)
@@ -130,7 +132,7 @@ class RandomUniform(Initializer):
 class TruncatedNormal(Initializer):
   """Initializer that generates a truncated normal distribution.

-  These values are similar to values from a `random_normal_initializer`
+  These values are similar to values from a `RandomNormal`
   except that values more than two standard deviations from the mean
   are discarded and re-drawn. This is the recommended initializer for
   neural network weights and filters.
@@ -161,6 +163,7 @@ class VarianceScaling(Initializer):

   With `distribution="normal"`, samples are drawn from a truncated normal
   distribution centered on zero, with `stddev = sqrt(scale / n)` where n is:

     - number of input units in the weight tensor, if mode = "fan_in"
+    - number of output units, if mode = "fan_out"
     - average of the numbers of input and output units, if mode = "fan_avg"

@@ -244,7 +244,7 @@ class _Conv(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -289,7 +289,7 @@ class Conv1D(_Conv):
           any `dilation_rate` value != 1.
       padding: One of `"valid"`, `"causal"` or `"same"` (case-insensitive).
           `"causal"` results in causal (dilated) convolutions, e.g. output[t]
-          depends solely on input[:t-1]. Useful when modeling temporal data
+          does not depend on input[t+1:]. Useful when modeling temporal data
           where the model should not violate the temporal order.
           See [WaveNet: A Generative Model for Raw Audio, section
             2.1](https://arxiv.org/abs/1609.03499).
@@ -395,9 +395,9 @@ class Conv2D(_Conv):
           one of `channels_last` (default) or `channels_first`.
           The ordering of the dimensions in the inputs.
           `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
           corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
           It defaults to the `image_data_format` value found in your
           Keras config file at `~/.keras/keras.json`.
           If you never set it, then it will be "channels_last".
@@ -621,7 +621,7 @@ class Conv2DTranspose(Conv2D):

   Arguments:
       filters: Integer, the dimensionality of the output space
-          (i.e. the number output of filters in the convolution).
+          (i.e. the number of output filters in the convolution).
       kernel_size: An integer or tuple/list of 2 integers, specifying the
           width and height of the 2D convolution window.
           Can be a single integer to specify the same value for
@@ -637,9 +637,9 @@ class Conv2DTranspose(Conv2D):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -688,7 +688,7 @@ class Conv2DTranspose(Conv2D):
                kernel_size,
                strides=(1, 1),
                padding='valid',
-               data_format='channels_last',
+               data_format=None,
                activation=None,
                use_bias=True,
                kernel_initializer='glorot_uniform',
@@ -845,9 +845,9 @@ class SeparableConv2D(Conv2D):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -1079,9 +1079,9 @@ class UpSampling2D(Layer):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -1257,7 +1257,7 @@ class ZeroPadding2D(Layer):
           - If tuple of 2 ints:
               interpreted as two different
               symmetric padding values for height and width:
-              `(symmetric_height_pad, symmetrc_width_pad)`.
+              `(symmetric_height_pad, symmetric_width_pad)`.
           - If tuple of 2 tuples of 2 ints:
               interpreted as
               `((top_pad, bottom_pad), (left_pad, right_pad))`
@@ -1265,9 +1265,9 @@ class ZeroPadding2D(Layer):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -1498,7 +1498,7 @@ class Cropping2D(Layer):
           - If tuple of 2 ints:
               interpreted as two different
               symmetric cropping values for height and width:
-              `(symmetric_height_crop, symmetrc_width_crop)`.
+              `(symmetric_height_crop, symmetric_width_crop)`.
           - If tuple of 2 tuples of 2 ints:
               interpreted as
               `((top_crop, bottom_crop), (left_crop, right_crop))`
@@ -1506,9 +1506,9 @@ class Cropping2D(Layer):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -357,7 +357,7 @@ class ConvLSTM2D(ConvRecurrent2D):
       self.states = [None, None]

     if self.data_format == 'channels_first':
-      channel_axis = 1
+      channel_axis = 2
     else:
       channel_axis = -1
     if input_shape[channel_axis] is None:
@@ -88,7 +88,7 @@ class Dropout(Layer):
   """Applies Dropout to the input.

   Dropout consists in randomly setting
-  a fraction `p` of input units to 0 at each update during training time,
+  a fraction `rate` of input units to 0 at each update during training time,
   which helps prevent overfitting.

   Arguments:
@@ -140,7 +140,7 @@ class SpatialDropout1D(Dropout):
   between feature maps and should be used instead.

   Arguments:
-      p: float between 0 and 1. Fraction of the input units to drop.
+      rate: float between 0 and 1. Fraction of the input units to drop.

   Input shape:
       3D tensor with shape:
@@ -775,7 +775,7 @@ class Dense(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -59,7 +59,8 @@ class LocallyConnected1D(Layer):
           specifying the stride length of the convolution.
           Specifying any stride value != 1 is incompatible with specifying
           any `dilation_rate` value != 1.
-      padding: One of `"valid"` or `"same"` (case-insensitive).
+      padding: Currently only supports `"valid"` (case-insensitive).
+          `"same"` may be supported in the future.
       activation: Activation function to use.
           If you don't specify anything, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
@@ -188,7 +189,7 @@ class LocallyConnected1D(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -239,16 +240,15 @@ class LocallyConnected2D(Layer):
           specifying the strides of the convolution along the width and height.
           Can be a single integer to specify the same value for
           all spatial dimensions.
-          Specifying any stride value != 1 is incompatible with specifying
-          any `dilation_rate` value != 1.
-      padding: one of `"valid"` or `"same"` (case-insensitive).
+      padding: Currently only support `"valid"` (case-insensitive).
+          `"same"` will be supported in future.
       data_format: A string,
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -460,7 +460,7 @@ class LocallyConnected2D(Layer):
         'kernel_initializer':
             initializers.serialize(self.kernel_initializer),
         'bias_initializer':
-            initializers.serialize(self.kernel_initializer),
+            initializers.serialize(self.bias_initializer),
         'kernel_regularizer':
             regularizers.serialize(self.kernel_regularizer),
         'bias_regularizer':
@@ -41,6 +41,44 @@ class _Merge(Layer):
   def _merge_function(self, inputs):
     raise NotImplementedError

+  def _compute_elemwise_op_output_shape(self, shape1, shape2):
+    """Computes the shape of the resultant of an elementwise operation.
+
+    Arguments:
+        shape1: tuple or None. Shape of the first tensor
+        shape2: tuple or None. Shape of the second tensor
+
+    Returns:
+        expected output shape when an element-wise operation is
+        carried out on 2 tensors with shapes shape1 and shape2.
+        tuple or None.
+
+    Raises:
+        ValueError: if shape1 and shape2 are not compatible for
+            element-wise operations.
+    """
+    if None in [shape1, shape2]:
+      return None
+    elif len(shape1) < len(shape2):
+      return self._compute_elemwise_op_output_shape(shape2, shape1)
+    elif not shape2:
+      return shape1
+    output_shape = list(shape1[:-len(shape2)])
+    for i, j in zip(shape1[-len(shape2):], shape2):
+      if i is None or j is None:
+        output_shape.append(None)
+      elif i == 1:
+        output_shape.append(j)
+      elif j == 1:
+        output_shape.append(i)
+      else:
+        if i != j:
+          raise ValueError('Operands could not be broadcast '
+                           'together with shapes ' + str(shape1) + ' ' + str(
+                               shape2))
+        output_shape.append(i)
+    return tuple(output_shape)
+
   def build(self, input_shape):
     # Used purely for shape validation.
     if not isinstance(input_shape, list):
@@ -49,23 +87,107 @@ class _Merge(Layer):
       raise ValueError('A merge layer should be called '
                        'on a list of at least 2 inputs. '
                        'Got ' + str(len(input_shape)) + ' inputs.')
-    if all([shape is None for shape in input_shape]):
-      return
-    input_shapes = [
-        tuple(tensor_shape.TensorShape(shape).as_list())
-        for shape in input_shape
-    ]
-    # TODO(fchollet): handle shapes with None entries.
-    input_shapes_set = set(input_shapes)
-    if None in input_shapes_set:
-      input_shapes_set.remove(None)
-    if len(input_shapes_set) > 1:
-      raise ValueError('Only tensors of same shape can '
-                       'be merged by layer' + self.name +
-                       ' Got input shapes: %s' % input_shapes)
+    batch_sizes = [s[0] for s in input_shape if s is not None]
+    batch_sizes = set(batch_sizes)
+    batch_sizes -= set([None])
+    if len(batch_sizes) > 1:
+      raise ValueError('Can not merge tensors with different '
+                       'batch sizes. Got tensors with shapes : ' + str(
+                           input_shape))
+    if input_shape[0] is None:
+      output_shape = None
+    else:
+      output_shape = input_shape[0][1:]
+    for i in range(1, len(input_shape)):
+      if input_shape[i] is None:
+        shape = None
+      else:
+        shape = input_shape[i][1:]
+      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
+    # If the inputs have different ranks, we have to reshape them
+    # to make them broadcastable.
+    if None not in input_shape and len(set(map(len, input_shape))) == 1:
+      self._reshape_required = False
+    else:
+      self._reshape_required = True

   def call(self, inputs):
-    return self._merge_function(inputs)
+    if self._reshape_required:
+      reshaped_inputs = []
+      input_ndims = list(map(K.ndim, inputs))
+      if None not in input_ndims:
+        # If ranks of all inputs are available,
+        # we simply expand each of them at axis=1
+        # until all of them have the same rank.
+        max_ndim = max(input_ndims)
+        for x in inputs:
+          x_ndim = K.ndim(x)
+          for _ in range(max_ndim - x_ndim):
+            x = K.expand_dims(x, 1)
+          reshaped_inputs.append(x)
+        return self._merge_function(reshaped_inputs)
+      else:
+        # Transpose all inputs so that batch size is the last dimension.
+        # (batch_size, dim1, dim2, ... ) -> (dim1, dim2, ... , batch_size)
+        transposed = False
+        for x in inputs:
+          x_ndim = K.ndim(x)
+          if x_ndim is None:
+            x_shape = K.shape(x)
+            batch_size = x_shape[0]
+            new_shape = K.concatenate([x_shape[1:], K.expand_dims(batch_size)])
+            x_transposed = K.reshape(x,
+                                     K.stack([batch_size, K.prod(x_shape[1:])]))
+            x_transposed = K.permute_dimensions(x_transposed, (1, 0))
+            x_transposed = K.reshape(x_transposed, new_shape)
+            reshaped_inputs.append(x_transposed)
+            transposed = True
+          elif x_ndim > 1:
+            dims = list(range(1, x_ndim)) + [0]
+            reshaped_inputs.append(K.permute_dimensions(x, dims))
+            transposed = True
+          else:
+            # We don't transpose inputs if they are 1D vectors or scalars.
+            reshaped_inputs.append(x)
+        y = self._merge_function(reshaped_inputs)
+        y_ndim = K.ndim(y)
+        if transposed:
+          # If inputs have been transposed, we have to transpose the output too.
+          if y_ndim is None:
+            y_shape = K.shape(y)
+            y_ndim = K.shape(y_shape)[0]
+            batch_size = y_shape[y_ndim - 1]
+            new_shape = K.concatenate(
+                [K.expand_dims(batch_size), y_shape[:y_ndim - 1]])
+            y = K.reshape(y, (-1, batch_size))
+            y = K.permute_dimensions(y, (1, 0))
+            y = K.reshape(y, new_shape)
+          elif y_ndim > 1:
+            dims = [y_ndim - 1] + list(range(y_ndim - 1))
+            y = K.permute_dimensions(y, dims)
+        return y
+    else:
+      return self._merge_function(inputs)
+
+  def compute_output_shape(self, input_shape):
+    if input_shape[0] is None:
+      output_shape = None
+    else:
+      output_shape = input_shape[0][1:]
+    for i in range(1, len(input_shape)):
+      if input_shape[i] is None:
+        shape = None
+      else:
+        shape = input_shape[i][1:]
+      output_shape = self._compute_elemwise_op_output_shape(output_shape, shape)
+    batch_sizes = [s[0] for s in input_shape if s is not None]
+    batch_sizes = set(batch_sizes)
+    batch_sizes -= set([None])
+    if len(batch_sizes) == 1:
+      output_shape = (list(batch_sizes)[0],) + output_shape
+    else:
+      output_shape = (None,) + output_shape
+    return output_shape

   def compute_mask(self, inputs, mask=None):
     if mask is None:
@@ -219,8 +341,8 @@ class Concatenate(_Merge):
     for input_i, mask_i in zip(inputs, mask):
       if mask_i is None:
         # Input is unmasked. Append all 1s to masks,
-        # but cast it to uint8 first
-        masks.append(K.cast(K.ones_like(input_i), 'uint8'))
+        # but cast it to bool first
+        masks.append(K.cast(K.ones_like(input_i), 'bool'))
       elif K.ndim(mask_i) < K.ndim(input_i):
         # Mask is smaller than the input, expand it
         masks.append(K.expand_dims(mask_i))
@@ -154,7 +154,7 @@ class BatchNormalization(Layer):
     broadcast_shape[self.axis] = input_shape[self.axis]

     # Determines whether broadcasting is needed.
-    needs_broadcasting = (sorted(reduction_axes) != range(ndim)[:-1])
+    needs_broadcasting = (sorted(reduction_axes) != list(range(ndim))[:-1])

     normed, mean, variance = K.normalize_batch_in_training(
         inputs, self.gamma, self.beta, reduction_axes, epsilon=self.epsilon)
@@ -199,9 +199,9 @@ class MaxPooling2D(_Pooling2D):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -255,9 +255,9 @@ class AveragePooling2D(_Pooling2D):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -542,9 +542,9 @@ class GlobalAveragePooling2D(_GlobalPooling2D):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -577,9 +577,9 @@ class GlobalMaxPooling2D(_GlobalPooling2D):
          one of `channels_last` (default) or `channels_first`.
          The ordering of the dimensions in the inputs.
          `channels_last` corresponds to inputs with shape
-          `(batch, width, height, channels)` while `channels_first`
+          `(batch, height, width, channels)` while `channels_first`
          corresponds to inputs with shape
-          `(batch, channels, width, height)`.
+          `(batch, channels, height, width)`.
          It defaults to the `image_data_format` value found in your
          Keras config file at `~/.keras/keras.json`.
          If you never set it, then it will be "channels_last".
@@ -105,8 +105,16 @@ class Recurrent(Layer):
       # now model.output_shape == (None, 32)
       # note: `None` is the batch dimension.

-      # for subsequent layers, not need to specify the input size:
+      # for subsequent layers, no need to specify the input size:
       model.add(LSTM(16))
+
+      # to stack recurrent layers, you must use return_sequences=True
+      # on any recurrent layer that feeds into another recurrent layer.
+      # note that you only need to specify the input size on the first layer.
+      model = Sequential()
+      model.add(LSTM(64, input_dim=64, input_length=10, return_sequences=True))
+      model.add(LSTM(32, return_sequences=True))
+      model.add(LSTM(10))
   ```

   Arguments:
@@ -116,7 +124,8 @@ class Recurrent(Layer):
       return_sequences: Boolean. Whether to return the last output
           in the output sequence, or the full sequence.
       go_backwards: Boolean (default False).
-          If True, process the input sequence backwards.
+          If True, process the input sequence backwards and return the
+          reversed sequence.
       stateful: Boolean (default False). If True, the last state
           for each sample at index i in a batch will be used as initial
          state for the sample of index i in the following batch.
@@ -398,6 +407,7 @@ class SimpleRNN(Recurrent):
       units: Positive integer, dimensionality of the output space.
       activation: Activation function to use.
-          If you don't specify anything, no activation is applied
+          If you pass None, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
       use_bias: Boolean, whether the layer uses a bias vector.
       kernel_initializer: Initializer for the `kernel` weights matrix,
@@ -547,7 +557,7 @@ class SimpleRNN(Recurrent):

   def get_constants(self, inputs, training=None):
     constants = []
-    if self.implementation == 0 and 0 < self.dropout < 1:
+    if self.implementation != 0 and 0 < self.dropout < 1:
       input_shape = K.int_shape(inputs)
       input_dim = input_shape[-1]
       ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -619,7 +629,7 @@ class GRU(Recurrent):
   Arguments:
       units: Positive integer, dimensionality of the output space.
       activation: Activation function to use.
-          If you don't specify anything, no activation is applied
+          If you pass None, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
       recurrent_activation: Activation function to use
           for the recurrent step.
@@ -792,7 +802,7 @@ class GRU(Recurrent):

   def get_constants(self, inputs, training=None):
     constants = []
-    if self.implementation == 0 and 0 < self.dropout < 1:
+    if self.implementation != 0 and 0 < self.dropout < 1:
       input_shape = K.int_shape(inputs)
       input_dim = input_shape[-1]
       ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -861,7 +871,7 @@ class GRU(Recurrent):
       if self.use_bias:
         x_z = K.bias_add(x_z, self.bias_z)
         x_r = K.bias_add(x_r, self.bias_r)
-        x_h = K.bias_add(x_r, self.bias_h)
+        x_h = K.bias_add(x_h, self.bias_h)
     else:
       raise ValueError('Unknown `implementation` mode.')
     z = self.recurrent_activation(x_z + K.dot(h_tm1 * rec_dp_mask[0],
@@ -924,7 +934,7 @@ class LSTM(Recurrent):
   Arguments:
       units: Positive integer, dimensionality of the output space.
       activation: Activation function to use.
-          If you don't specify anything, no activation is applied
+          If you pass None, no activation is applied
           (ie. "linear" activation: `a(x) = x`).
       recurrent_activation: Activation function to use
           for the recurrent step.
@@ -1127,7 +1137,7 @@ class LSTM(Recurrent):

   def get_constants(self, inputs, training=None):
     constants = []
-    if self.implementation == 0 and 0 < self.dropout < 1:
+    if self.implementation != 0 and 0 < self.dropout < 1:
       input_shape = K.int_shape(inputs)
       input_dim = input_shape[-1]
       ones = K.ones_like(K.reshape(inputs[:, 0, 0], (-1, 1)))
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ==============================================================================
+# pylint: disable=protected-access
 """Wrapper layers: layers that augment the functionality of another layer.
 """
 from __future__ import absolute_import
@@ -19,6 +20,7 @@ from __future__ import division
 from __future__ import print_function

 import copy
+import inspect

 from tensorflow.contrib.keras.python.keras import backend as K
 from tensorflow.contrib.keras.python.keras.engine import InputSpec
@@ -70,9 +72,10 @@ class Wrapper(Layer):
     return dict(list(base_config.items()) + list(config.items()))

   @classmethod
-  def from_config(cls, config):
+  def from_config(cls, config, custom_objects=None):
     from tensorflow.contrib.keras.python.keras.layers import deserialize as deserialize_layer  # pylint: disable=g-import-not-at-top
-    layer = deserialize_layer(config.pop('layer'))
+    layer = deserialize_layer(
+        config.pop('layer'), custom_objects=custom_objects)
     return cls(layer, **config)


@@ -188,12 +191,15 @@ class Bidirectional(Wrapper):
           If None, the outputs will not be combined,
           they will be returned as a list.

+  Raises:
+      ValueError: In case of invalid `merge_mode` argument.
+
   Examples:

   ```python
       model = Sequential()
       model.add(Bidirectional(LSTM(10, return_sequences=True), input_shape=(5,
-      10)))
+                                                                            10)))
       model.add(Bidirectional(LSTM(10)))
       model.add(Dense(5))
       model.add(Activation('softmax'))
@@ -242,29 +248,47 @@ class Bidirectional(Wrapper):
       shape = self.forward_layer._compute_output_shape(input_shape)  # pylint: disable=protected-access
       return [shape, copy.copy(shape)]

-  def call(self, inputs, mask=None):
-    y = self.forward_layer.call(inputs, mask)
-    y_rev = self.backward_layer.call(inputs, mask)
+  def call(self, inputs, training=None, mask=None):
+    kwargs = {}
+    func_args = inspect.getargspec(self.layer.call).args
+    if 'training' in func_args:
+      kwargs['training'] = training
+    if 'mask' in func_args:
+      kwargs['mask'] = mask
+
+    y = self.forward_layer.call(inputs, **kwargs)
+    y_rev = self.backward_layer.call(inputs, **kwargs)
     if self.return_sequences:
       y_rev = K.reverse(y_rev, 1)
     if self.merge_mode == 'concat':
-      return K.concatenate([y, y_rev])
+      output = K.concatenate([y, y_rev])
     elif self.merge_mode == 'sum':
-      return y + y_rev
+      output = y + y_rev
     elif self.merge_mode == 'ave':
-      return (y + y_rev) / 2
+      output = (y + y_rev) / 2
     elif self.merge_mode == 'mul':
-      return y * y_rev
+      output = y * y_rev
     elif self.merge_mode is None:
-      return [y, y_rev]
+      output = [y, y_rev]
+
+    # Properly set learning phase
+    if 0 < self.layer.dropout + self.layer.recurrent_dropout:
+      if self.merge_mode is None:
+        for out in output:
+          out._uses_learning_phase = True
+      else:
+        output._uses_learning_phase = True
+    return output

   def reset_states(self):
     self.forward_layer.reset_states()
     self.backward_layer.reset_states()

   def build(self, input_shape):
-    self.forward_layer.build(input_shape)
-    self.backward_layer.build(input_shape)
+    with K.name_scope(self.forward_layer.name):
+      self.forward_layer.build(input_shape)
+    with K.name_scope(self.backward_layer.name):
+      self.backward_layer.build(input_shape)
     self.built = True

   def compute_mask(self, inputs, mask):
@@ -43,12 +43,15 @@ def binary_accuracy(y_true, y_pred):


 def categorical_accuracy(y_true, y_pred):
-  return K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1))
+  return K.cast(
+      K.equal(K.argmax(y_true, axis=-1), K.argmax(y_pred, axis=-1)), K.floatx())


 def sparse_categorical_accuracy(y_true, y_pred):
-  return K.equal(
-      K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1), K.floatx()))
+  return K.cast(
+      K.equal(
+          K.max(y_true, axis=-1), K.cast(K.argmax(y_pred, axis=-1),
+                                         K.floatx())), K.floatx())


 def top_k_categorical_accuracy(y_true, y_pred, k=5):
@@ -207,7 +207,7 @@ def load_model(filepath, custom_objects=None):
ValueError: In case of an invalid savefile.
"""
if h5py is None:
raise ImportError('`save_model` requires h5py.')
raise ImportError('`load_model` requires h5py.')

if not custom_objects:
custom_objects = {}
@@ -1006,7 +1006,7 @@ class Sequential(Model):
steps_per_epoch: Total number of steps (batches of samples)
to yield from `generator` before declaring one epoch
finished and starting the next epoch. It should typically
be equal to the number of unique samples if your dataset
be equal to the number of unique samples of your dataset
divided by the batch size.
epochs: Integer, total number of iterations on the data.
verbose: Verbosity mode, 0, 1, or 2.
@@ -1017,8 +1017,10 @@ class Sequential(Model):
- A tuple (inputs, targets, sample_weights).
validation_steps: Only relevant if `validation_data`
is a generator.
Number of samples to use from validation generator
at the end of every epoch.
Number of steps to yield from validation generator
at the end of every epoch. It should typically
be equal to the number of unique samples of your
validation dataset divided by the batch size.
class_weight: Dictionary mapping class indices to a weight
for the class.
max_q_size: Maximum size for the generator queue
@@ -1050,7 +1052,7 @@ class Sequential(Model):
# and labels, from each line in the file
x, y = process_line(line)
yield (x, y)
f.close()
f.close()

model.fit_generator(generate_arrays_from_file('/my_file.txt'),
samples_per_epoch=10000, epochs=10)
@@ -1119,7 +1121,8 @@ class Sequential(Model):
steps,
max_q_size=10,
workers=1,
pickle_safe=False):
pickle_safe=False,
verbose=0):
"""Generates predictions for the input samples from a data generator.

The generator should return the same kind of data as accepted by
@@ -1136,6 +1139,7 @@ class Sequential(Model):
relies on multiprocessing, you should not pass
non picklable arguments to the generator
as they can't be passed easily to children processes.
verbose: verbosity mode, 0 or 1.

Returns:
A Numpy array of predictions.
@@ -1147,7 +1151,8 @@ class Sequential(Model):
steps,
max_q_size=max_q_size,
workers=workers,
pickle_safe=pickle_safe)
pickle_safe=pickle_safe,
verbose=verbose)

def get_config(self):
config = []
@@ -1159,9 +1164,9 @@ class Sequential(Model):
return copy.deepcopy(config)

@classmethod
def from_config(cls, config):
def from_config(cls, config, custom_objects=None):
model = cls()
for conf in config:
layer = layer_module.deserialize(conf)
layer = layer_module.deserialize(conf, custom_objects=custom_objects)
model.add(layer)
return model
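The `from_config` hunk above threads `custom_objects` through layer deserialization. A pure-Python sketch of why that is needed: deserialization maps class names back to classes, and user-defined names are only resolvable through a mapping the caller supplies. `MyLayer`, `REGISTRY`, and `deserialize` below are hypothetical stand-ins, not Keras APIs:

```python
# Sketch: resolving serialized class names with and without custom_objects.
REGISTRY = {'Dense': object}  # stand-in for the built-in layer registry

def deserialize(conf, custom_objects=None):
    table = dict(REGISTRY)
    table.update(custom_objects or {})
    try:
        return table[conf['class_name']]
    except KeyError:
        raise ValueError('Unknown layer: ' + conf['class_name'])

class MyLayer(object):  # hypothetical user-defined layer
    pass

conf = {'class_name': 'MyLayer'}
assert deserialize(conf, custom_objects={'MyLayer': MyLayer}) is MyLayer
```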
@@ -785,7 +785,7 @@ class Iterator(object):
index_array = np.random.permutation(n)

current_index = (self.batch_index * batch_size) % n
if n >= current_index + batch_size:
if n > current_index + batch_size:
current_batch_size = batch_size
self.batch_index += 1
else:
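As far as one can tell from this hunk, the `>=` to `>` change fixes an exact-fit boundary: when `n` is a multiple of `batch_size`, the old comparison always took the "full batch, advance" branch, so `batch_index` never reset to 0, which is what gates per-epoch work such as re-shuffling in the surrounding `Iterator` code. A plain-Python sketch of just that index bookkeeping (an assumption-labeled simplification, not the real class):

```python
# Sketch of the batch_index reset under the old (>=) vs. new (>) comparison.
def step(n, batch_index, batch_size, strict):
    current_index = (batch_index * batch_size) % n
    full = ((n > current_index + batch_size) if strict
            else (n >= current_index + batch_size))
    return batch_index + 1 if full else 0

n, bs = 10, 5
old = new = 0
for _ in range(4):
    old = step(n, old, bs, strict=False)
    new = step(n, new, bs, strict=True)
print(old, new)  # old keeps growing (4); new wraps back to 0 each epoch
```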
@@ -172,7 +172,8 @@ def deserialize_keras_object(identifier,
else:
fn = module_objects.get(function_name)
if fn is None:
raise ValueError('Unknown ' + printable_module_name, ':' + class_name)
raise ValueError('Unknown ' + printable_module_name,
':' + function_name)
return fn
else:
raise ValueError('Could not interpret serialized ' + printable_module_name +
@@ -215,6 +216,8 @@ def func_load(code, defaults=None, closure=None, globs=None):
"""
if isinstance(code, (tuple, list)):  # unpack previous dump
code, defaults, closure = code
if isinstance(defaults, list):
defaults = tuple(defaults)
code = marshal.loads(code.encode('raw_unicode_escape'))
if globs is None:
globs = globals()
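The list-to-tuple conversion added to `func_load` matters because `types.FunctionType` requires its defaults argument to be a tuple, and a JSON round trip of a serialized function turns that tuple into a list. A minimal, self-contained demonstration:

```python
# FunctionType rejects a list of defaults; the tuple() conversion fixes it.
import types

def f(a, b=1, c=2):
    return a + b + c

defaults = list(f.__defaults__)  # what comes back from JSON: [1, 2]
try:
    types.FunctionType(f.__code__, globals(), 'g', defaults)
except TypeError as e:
    print(e)  # e.g. "arg 4 (defaults) must be None or tuple"

g = types.FunctionType(f.__code__, globals(), 'g', tuple(defaults))
print(g(10))  # 13
```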
@@ -171,7 +171,7 @@ def count_total_params(layers, layer_set=None):
[K.count_params(p) for p in layer.trainable_weights])
non_trainable_count += np.sum(
[K.count_params(p) for p in layer.non_trainable_weights])
return trainable_count, non_trainable_count
return int(trainable_count), int(non_trainable_count)


def convert_all_kernels_in_model(model):

@@ -194,6 +194,36 @@ class KerasClassifier(BaseWrapper):
"""Implementation of the scikit-learn classifier API for Keras.
"""

def fit(self, x, y, **kwargs):
"""Constructs a new model with `build_fn` & fit the model to `(x, y)`.

Arguments:
x : array-like, shape `(n_samples, n_features)`
Training samples where n_samples is the number of samples
and n_features is the number of features.
y : array-like, shape `(n_samples,)` or `(n_samples, n_outputs)`
True labels for X.
**kwargs: dictionary arguments
Legal arguments are the arguments of `Sequential.fit`

Returns:
history : object
details about the training history at each epoch.

Raises:
ValueError: In case of invalid shape for `y` argument.
"""
y = np.array(y)
if len(y.shape) == 2 and y.shape[1] > 1:
self.classes_ = np.arange(y.shape[1])
elif (len(y.shape) == 2 and y.shape[1] == 1) or len(y.shape) == 1:
self.classes_ = np.unique(y)
y = np.searchsorted(self.classes_, y)
else:
raise ValueError('Invalid shape for y: ' + str(y.shape))
self.n_classes_ = len(self.classes_)
return super(KerasClassifier, self).fit(x, y, **kwargs)

def predict(self, x, **kwargs):
"""Returns the class predictions for the given test data.

@@ -210,7 +240,8 @@ class KerasClassifier(BaseWrapper):
Class predictions.
"""
kwargs = self.filter_sk_params(Sequential.predict_classes, kwargs)
return self.model.predict_classes(x, **kwargs)
classes = self.model.predict_classes(x, **kwargs)
return self.classes_[classes]

def predict_proba(self, x, **kwargs):
"""Returns class probability estimates for the given test data.
@@ -261,6 +292,7 @@ class KerasClassifier(BaseWrapper):
compute accuracy. You should pass `metrics=["accuracy"]` to
the `.compile()` method of the model.
"""
y = np.searchsorted(self.classes_, y)
kwargs = self.filter_sk_params(Sequential.evaluate, kwargs)

loss_name = self.model.loss
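The `KerasClassifier` changes above introduce a label round trip: `fit` records the sorted class labels in `classes_` and encodes `y` as indices via `np.searchsorted`, and `predict` decodes the network's index output back through `classes_`. A numpy sketch of that round trip:

```python
# fit() encodes arbitrary labels as indices; predict() decodes them back.
import numpy as np

y = np.array(['cat', 'dog', 'cat', 'bird'])
classes_ = np.unique(y)                   # ['bird' 'cat' 'dog'] (sorted)
encoded = np.searchsorted(classes_, y)    # [1 2 1 0], what the network sees
predicted_indices = np.array([2, 0, 1])   # stand-in for predict_classes output
print(classes_[predicted_indices])        # ['dog' 'bird' 'cat']
```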
@@ -22,11 +22,13 @@ from six.moves import xrange  # pylint: disable=redefined-builtin
from tensorflow.contrib.framework.python.framework import tensor_util as contrib_tensor_util
from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import clip_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import data_flow_ops
from tensorflow.python.ops import embedding_ops
@@ -555,8 +557,13 @@ def _sampled_scattered_embedding_lookup_sparse(params,
name=name_scope)


def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids,
sp_weights, partition_strategy="mod", name=None, combiner=None,
def embedding_lookup_sparse_with_distributed_aggregation(
params,
sp_ids,
sp_weights,
partition_strategy="mod",
name=None,
combiner=None,
max_norm=None):
"""Computes embeddings for the given ids and weights.

@@ -638,8 +645,13 @@ def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids,

weights = None if ignore_weights else sp_weights.values
embeddings = _embedding_lookup_with_distributed_aggregation(
params, ids, partition_strategy=partition_strategy, max_norm=max_norm,
weights=weights, idx=idx, segment_ids=segment_ids)
params,
ids,
partition_strategy=partition_strategy,
max_norm=max_norm,
weights=weights,
idx=idx,
segment_ids=segment_ids)
# Set weights to all one if ignore weights.
if ignore_weights:
weights = array_ops.fill([array_ops.shape(segment_ids)[0]], 1)
@@ -648,13 +660,13 @@ def embedding_lookup_sparse_with_distributed_aggregation(params, sp_ids,
# Reshape weights.
ones = array_ops.fill(
array_ops.expand_dims(array_ops.rank(embeddings) - 1, 0), 1)
bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones],
0)
bcast_weights_shape = array_ops.concat([array_ops.shape(weights), ones], 0)
orig_weights_shape = weights.get_shape()
weights = array_ops.reshape(weights, bcast_weights_shape)
if embeddings.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(embeddings.get_shape().ndims - 1)]))
weights.set_shape(
orig_weights_shape.concatenate(
[1 for _ in range(embeddings.get_shape().ndims - 1)]))

if combiner == "mean":
weight_sum = math_ops.segment_sum(weights, segment_ids)
@@ -677,16 +689,23 @@ def _do_gather(params, ids, validate_indices=True, name=None):
params, ids, name=name, validate_indices=validate_indices)


def _embedding_lookup_with_distributed_aggregation(params, ids,
partition_strategy="mod", name=None, validate_indices=True, max_norm=None,
weights=None, idx=None, segment_ids=None):
""" Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
def _embedding_lookup_with_distributed_aggregation(params,
ids,
partition_strategy="mod",
name=None,
validate_indices=True,
max_norm=None,
weights=None,
idx=None,
segment_ids=None):
"""Lookup helper for embedding_lookup_sparse_with_distributed_aggregation."""
if params is None or params == []: # pylint: disable=g-explicit-bool-comparison
raise ValueError("Need at least one param")
if isinstance(params, variables.PartitionedVariable):
params = list(params) # Iterate to get the underlying Variables.
if not isinstance(params, list):
params = [params]

def maybe_normalize(x):
if max_norm is not None:
if x.get_shape().ndims is not None:
@@ -695,18 +714,18 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
ndims = array_ops.size(array_ops.shape(x))
return clip_ops.clip_by_norm(x, max_norm, axes=list(range(1, ndims)))
return x

with ops.name_scope(name, "embedding_lookup_with_distributed_aggregation",
params + [ids]) as name:
params + [ids]) as name:
np = len(params) # Number of partitions
# Preserve the resource variable status to avoid accidental dense reads.
if not any(isinstance(p, resource_variable_ops.ResourceVariable)
for p in params):
if not any(
isinstance(p, resource_variable_ops.ResourceVariable) for p in params):
params = ops.convert_n_to_tensor_or_indexed_slices(params, name="params")
if np == 1:
with ops.colocate_with(params[0]):
ret = maybe_normalize(
_do_gather(
params[0], ids, validate_indices=validate_indices))
_do_gather(params[0], ids, validate_indices=validate_indices))
ignore_weights = weights is None
if not ignore_weights:
if weights.dtype != ret.dtype:
@@ -720,8 +739,9 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
weights = array_ops.reshape(weights, bcast_weights_shape)
# Set weights shape after reshape
if ret.get_shape().ndims is not None:
weights.set_shape(orig_weights_shape.concatenate(
[1 for _ in range(ret.get_shape().ndims - 1)]))
weights.set_shape(
orig_weights_shape.concatenate(
[1 for _ in range(ret.get_shape().ndims - 1)]))
ret *= weights
return math_ops.segment_sum(ret, segment_ids, name=name)
else:
@@ -757,18 +777,16 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
ids_per_partition = num_total_ids // np
extras = num_total_ids % np

p_assignments = math_ops.maximum(
flat_ids // (ids_per_partition + 1),
(flat_ids - extras) // ids_per_partition)
p_assignments = math_ops.maximum(flat_ids // (ids_per_partition + 1), (
flat_ids - extras) // ids_per_partition)

# Emulate a conditional using a boolean indicator tensor
is_in_first_extras_partitions = math_ops.cast(
p_assignments < extras, flat_ids.dtype)
new_ids = (
is_in_first_extras_partitions * (
flat_ids % (ids_per_partition + 1)) +
(1 - is_in_first_extras_partitions) * (
(flat_ids - extras) % ids_per_partition))
is_in_first_extras_partitions = math_ops.cast(p_assignments < extras,
flat_ids.dtype)
new_ids = (is_in_first_extras_partitions * (flat_ids %
(ids_per_partition + 1)) +
(1 - is_in_first_extras_partitions) * (
(flat_ids - extras) % ids_per_partition))
else:
raise ValueError("Unrecognized partition strategy: " +
partition_strategy)
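The reformatted arithmetic above decides, for the contiguous ("div"-style) partition branch, which shard each id lands in and its offset within that shard: the first `extras` partitions hold `ids_per_partition + 1` ids, the rest hold `ids_per_partition`, and the two floor divisions pick whichever assignment is consistent. A numpy sketch with small illustrative numbers:

```python
# Numpy version of the partition-assignment math in the hunk above.
import numpy as np

num_total_ids, num_partitions = 10, 3
ids_per_partition = num_total_ids // num_partitions   # 3
extras = num_total_ids % num_partitions               # 1

flat_ids = np.arange(num_total_ids)
p_assignments = np.maximum(flat_ids // (ids_per_partition + 1),
                           (flat_ids - extras) // ids_per_partition)
in_first = (p_assignments < extras).astype(flat_ids.dtype)
new_ids = (in_first * (flat_ids % (ids_per_partition + 1)) +
           (1 - in_first) * ((flat_ids - extras) % ids_per_partition))

print(p_assignments)  # [0 0 0 0 1 1 1 2 2 2]
print(new_ids)        # [0 1 2 3 0 1 2 0 1 2]
```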
@@ -786,8 +804,8 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result.append(
_do_gather(params[p], gather_ids[p],
validate_indices=validate_indices))
_do_gather(
params[p], gather_ids[p], validate_indices=validate_indices))

ignore_weights = weights is None
if not ignore_weights:
@@ -802,17 +820,21 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
if element_shape.is_fully_defined():
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat(
[array_ops.shape(pindices[p]), element_shape], 0))
partitioned_result[p] = array_ops.reshape(
partitioned_result[p],
array_ops.concat([array_ops.shape(pindices[p]), element_shape],
0))
else:
with ops.colocate_with(params[0]):
params_shape = array_ops.shape(params[0])
for p in xrange(np):
with ops.colocate_with(params[p]):
partitioned_result[p] = array_ops.reshape(partitioned_result[p],
array_ops.concat([array_ops.shape(pindices[p]),
array_ops.slice(params_shape, [1], [-1])], 0))
partitioned_result[p] = array_ops.reshape(
partitioned_result[p],
array_ops.concat([
array_ops.shape(pindices[p]), array_ops.slice(
params_shape, [1], [-1])
], 0))
# Normalize each partition result.
for p in xrange(np):
with ops.colocate_with(params[p]):
@@ -823,7 +845,7 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
with ops.colocate_with(params[p]):
if partitioned_weight[p].dtype != partitioned_result[p].dtype:
partitioned_weight[p] = math_ops.cast(partitioned_weight[p],
partitioned_result[p].dtype)
partitioned_result[p].dtype)
# Reshape partition weights.
ones = array_ops.fill(
array_ops.expand_dims(
@@ -834,9 +856,12 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
partitioned_weight[p] = array_ops.reshape(partitioned_weight[p],
bcast_weights_shape)
if partitioned_result[p].get_shape().ndims is not None:
partitioned_weight[p].set_shape(orig_weights_shape.concatenate(
[1 for _ in range(
partitioned_result[p].get_shape().ndims - 1)]))
partitioned_weight[p].set_shape(
orig_weights_shape.concatenate([
1
for _ in range(partitioned_result[p].get_shape().ndims -
1)
]))
partitioned_result[p] *= partitioned_weight[p]
partitioned_segment_ids = []
for p in xrange(np):
@@ -874,5 +899,7 @@ def _embedding_lookup_with_distributed_aggregation(params, ids,
concat_segment_ids = array_ops.concat(partitioned_segment_ids, 0)
concat_partitioned_result = array_ops.concat(partitioned_result, 0)
return math_ops.unsorted_segment_sum(
concat_partitioned_result, concat_segment_ids,
math_ops.reduce_max(concat_segment_ids) + 1, name=name)
concat_partitioned_result,
concat_segment_ids,
math_ops.reduce_max(concat_segment_ids) + 1,
name=name)
@@ -31,8 +31,9 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import partitioned_variables
from tensorflow.python.platform import test
@@ -145,8 +146,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
self.assertAllClose(
embedding_lookup_result,
[(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
[0] * 4, embedding_weights[0][2],
(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])
[0] * 4, embedding_weights[0][2], (
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0])

def test_safe_embedding_lookup_sparse_partitioned(self):
with self.test_session():
@@ -171,8 +172,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
constant_op.constant(
w, dtype=dtypes.float64) for w in embedding_weights
constant_op.constant(w, dtype=dtypes.float64)
for w in embedding_weights
]
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
@@ -185,11 +186,10 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, sparse_weights).eval())

self.assertAllClose(
embedding_lookup_result,
[[(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) /
3.0, [0] * 4, [0] * 4],
[embedding_weights[0][2], [0] * 4, [0] * 4]])
self.assertAllClose(embedding_lookup_result, [[
(1.0 * embedding_weights[0][0] + 2.0 * embedding_weights[0][1]) / 3.0,
[0] * 4, [0] * 4
], [embedding_weights[0][2], [0] * 4, [0] * 4]])

def test_safe_embedding_lookup_sparse_3d_return_special_vector(self):
with self.test_session():
@@ -215,14 +215,13 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_lookup_result = (embedding_ops.safe_embedding_lookup_sparse(
embedding_weights, sparse_ids, None).eval())

self.assertAllClose(
embedding_lookup_result,
[[(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4,
[0] * 4], [
embedding_weights[0][2],
(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0,
[0] * 4
]])
self.assertAllClose(embedding_lookup_result, [[(
embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4, [
0
] * 4], [
embedding_weights[0][2],
(embedding_weights[0][0] + embedding_weights[0][1]) / 2.0, [0] * 4
]])

def test_safe_embedding_lookup_sparse_3d_partitioned(self):
with self.test_session():
@@ -233,13 +232,12 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
embedding_weights, sparse_ids, None).eval())

embedding_weights = list(itertools.chain(*embedding_weights))
self.assertAllClose(embedding_lookup_result,
[[(embedding_weights[0] + embedding_weights[1]) / 2.0,
[0] * 4, [0] * 4], [
embedding_weights[2],
(embedding_weights[0] + embedding_weights[1]) /
2.0, [0] * 4
]])
self.assertAllClose(embedding_lookup_result, [[
(embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4, [0] * 4
], [
embedding_weights[2],
(embedding_weights[0] + embedding_weights[1]) / 2.0, [0] * 4
]])

def test_safe_embedding_lookup_sparse_3d_partitioned_inconsistent_weights(
self):
@@ -251,8 +249,8 @@ class SafeEmbeddingLookupSparseTest(test.TestCase):
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids)
embedding_weights = [
constant_op.constant(
w, dtype=dtypes.float64) for w in embedding_weights
constant_op.constant(w, dtype=dtypes.float64)
for w in embedding_weights
]
self.assertRaises(ValueError, embedding_ops.safe_embedding_lookup_sparse,
embedding_weights, sparse_ids, sparse_weights)
@@ -301,8 +299,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
self.assertAllEqual(embedding_lookup_result[0],
embedding_lookup_result[1])
# Different embedding expected for different value.
embedding_diff = np.min((embedding_lookup_result[2] -
embedding_lookup_result[0])**2)
embedding_diff = np.min(
(embedding_lookup_result[2] - embedding_lookup_result[0])**2)
self.assertGreater(embedding_diff, 0)

def test_scattered_embedding_coverage(self):
@@ -320,8 +318,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase):
def test_scattered_embedding_multi_dimension(self):
with self.test_session():
embedding_weights = self._random_weights()
values = constant_op.constant(
[["foo", "bar", "bar"], ["bar", "bar", "foo"]])
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])

embedding_lookup_result = embedding_ops.scattered_embedding_lookup(
embedding_weights, values, dimension=10).eval()
@@ -340,8 +338,8 @@ class ScatteredEmbeddingLookupTest(test.TestCase):

embedding_lookup_result = (
embedding_ops.scattered_embedding_lookup_sparse(
embedding_weights, sparse_tensor, dimension=5, combiner="mean")
.eval())
embedding_weights, sparse_tensor, dimension=5,
combiner="mean").eval())

self.assertAllEqual(embedding_lookup_result.shape, [5, 5])
# Same non-zero embedding for the empty rows filled with a default value.
@@ -433,8 +431,8 @@ class SampledScatteredEmbeddingLookupTest(test.TestCase):
def test_hashed_embedding_multi_dimension(self):
with self.test_session():
embedding_weights = self._random_weights()
values = constant_op.constant(
[["foo", "bar", "bar"], ["bar", "bar", "foo"]])
values = constant_op.constant([["foo", "bar", "bar"],
["bar", "bar", "foo"]])
sampled_candidates = constant_op.constant(
[[[1, 3, 4, 6], [1, 7, 8, 9], [1, 7, 8, 9]],
[[1, 7, 8, 9], [1, 7, 8, 9], [1, 3, 4, 6]]])
@@ -491,8 +489,8 @@ class SampledScatteredEmbeddingLookupSparseTest(test.TestCase):
result = embedding_ops._sampled_scattered_embedding_lookup_sparse(
params, sp_values, dimension=5, hash_key=self._hash_key)

self.assertAllClose(result.eval(), [[0., 0., 0., 0., 0.],
[.3, .2, .2, .3, .1],
self.assertAllClose(result.eval(), [[0., 0., 0., 0.,
0.], [.3, .2, .2, .3, .1],
[0., 0., 0., 0., 0.]])

def test_output_values_with_sampled_candidates(self):
@@ -631,8 +629,8 @@ def _EmbeddingResult(params,
else:
partition = extras + (i - threshold) // ids_per_partition
offset = (i - threshold) % ids_per_partition
val = np.copy(params[_PName(partition) + ":0"][
offset, :]) * weight_value
val = np.copy(
params[_PName(partition) + ":0"][offset, :]) * weight_value
else:
assert False
if value_aggregation is None:
@@ -707,19 +705,19 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
grouped_ignored_weights = self._GroupByBatchEntry(
np.ones(np.sum(vals_per_batch_entry)), vals_per_batch_entry)

for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 5],
["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 5], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):

with self.test_session():
p, params, feed_dict = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
embedding_sum = \
embedding_ops.embedding_lookup_sparse_with_distributed_aggregation(
p,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)
p,
sp_ids,
None if ignore_weights else sp_weights,
combiner=combiner)

self.assertEqual(embedding_sum.get_shape().as_list(),
expected_lookup_result_shape)
@@ -731,8 +729,8 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
grouped_ids,
num_shards,
vocab_size,
weight_vals=grouped_ignored_weights if ignore_weights else
grouped_weights)
weight_vals=grouped_ignored_weights
if ignore_weights else grouped_weights)
if combiner == "mean":
np_embedding_sum /= np.reshape(np_weight_sum, (batch_size, 1, 1))
if combiner == "sqrtn":
@@ -744,12 +742,12 @@ class EmbeddingLookupSparseWithDistributedAggregationTest(test.TestCase):
vocab_size = 12
batch_size = 4
param_shape = [2, 3]
sp_ids, sp_weights, _, _, _ = (
self._RandomIdsAndWeights(batch_size, vocab_size))
sp_ids, sp_weights, _, _, _ = (self._RandomIdsAndWeights(
batch_size, vocab_size))

for num_shards, combiner, dtype, ignore_weights in itertools.product([1, 3],
["sum", "mean", "sqrtn"], [dtypes.float32, dtypes.float64],
[True, False]):
for num_shards, combiner, dtype, ignore_weights in itertools.product(
[1, 3], ["sum", "mean", "sqrtn"], [dtypes.float32,
dtypes.float64], [True, False]):
with self.test_session():
x, params, _ = _EmbeddingParams(
num_shards, vocab_size, shape=param_shape, dtype=dtype)
@@ -1942,6 +1942,7 @@ def separable_convolution2d(
dtype=dtype,
initializer=biases_initializer,
regularizer=biases_regularizer,
trainable=trainable,
collections=biases_collections)
outputs = nn.bias_add(outputs, biases)

@@ -2979,6 +2979,20 @@ class SeparableConv2dTest(test.TestCase):
sess.run(init_op)
sess.run(net, feed_dict={images_placeholder: images})

def testTrainableFlagIsPassedOn(self):
for trainable in [True, False]:
for num_filters in [None, 8]:
with ops.Graph().as_default():
input_size = [5, 10, 12, 3]

images = random_ops.random_uniform(input_size, seed=1)
layers_lib.separable_conv2d(
images, num_filters, [3, 3], 1, trainable=trainable)
model_variables = variables.get_model_variables()
trainable_variables = variables_lib.trainable_variables()
for model_variable in model_variables:
self.assertEqual(trainable, model_variable in trainable_variables)


class ScaleGradientTests(test.TestCase):
"""Simple tests of the scale_gradient function."""

@@ -22,6 +22,7 @@ from __future__ import print_function
# pylint: disable=unused-import
from tensorflow.python.estimator.inputs.queues.feeding_functions import _ArrayFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _enqueue_data as enqueue_data
from tensorflow.python.estimator.inputs.queues.feeding_functions import _GeneratorFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _OrderedDictNumpyFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _PandasFeedFn
from tensorflow.python.estimator.inputs.queues.feeding_functions import _GeneratorFeedFn
@@ -26,6 +26,7 @@ import six
from tensorflow.contrib import framework as contrib_framework
from tensorflow.contrib.framework import get_graph_from_inputs
from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import metric_key
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.python.estimator import model_fn as core_model_fn_lib
from tensorflow.python.estimator.export import export_output as core_export_lib
@@ -255,12 +256,20 @@ class ModelFnOps(
export_outputs_dict = {key: _export_output(*val) for key, val in
output_alternatives.items()}

def _get_eval_metric_ops():
"""Returns self.eval_metric_ops without loss metric."""
result = {}
for key, value in six.iteritems(self.eval_metric_ops):
if key != metric_key.MetricKey.LOSS:
result[key] = value
return result

return core_model_fn_lib.EstimatorSpec(
mode=mode,
predictions=self.predictions,
loss=self.loss,
train_op=self.train_op,
eval_metric_ops=self.eval_metric_ops,
eval_metric_ops=_get_eval_metric_ops(),
export_outputs=export_outputs_dict,
training_chief_hooks=self.training_chief_hooks,
training_hooks=self.training_hooks,
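Core Estimator reserves the loss entry in `eval_metric_ops`, so the conversion above forwards everything except that key. A plain-dict sketch of the filter, assuming `metric_key.MetricKey.LOSS == 'loss'` as in contrib.learn at the time:

```python
# Plain-dict equivalent of _get_eval_metric_ops: drop the reserved loss key.
eval_metric_ops = {'accuracy': ('value_op', 'update_op'),
                   'loss': ('value_op', 'update_op')}
filtered = {k: v for k, v in eval_metric_ops.items() if k != 'loss'}
print(sorted(filtered))  # ['accuracy']
```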
@@ -18,6 +18,7 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import model_fn
@@ -51,19 +52,26 @@ class ModelFnopsTest(test.TestCase):
predictions=predictions,
loss=constant_op.constant([1]),
train_op=control_flow_ops.no_op(),
eval_metric_ops={"metric_key": (control_flow_ops.no_op(),
control_flow_ops.no_op())},
eval_metric_ops={
"metric_key": (constant_op.constant(1.), control_flow_ops.no_op()),
"loss": (constant_op.constant(1.), control_flow_ops.no_op()),
},
# zzz
training_chief_hooks=[basic_session_run_hooks.StepCounterHook()],
training_hooks=[basic_session_run_hooks.StepCounterHook()],
output_alternatives=output_alternatives,
scaffold=monitored_session.Scaffold())

def assertEquals_except_export(self, model_fn_ops, estimator_spec):
def assertEquals_except_export_and_eval_loss(
self, model_fn_ops, estimator_spec):
expected_eval_metric_ops = {}
for key, value in six.iteritems(model_fn_ops.eval_metric_ops):
if key != "loss":
expected_eval_metric_ops[key] = value
self.assertEqual(model_fn_ops.predictions, estimator_spec.predictions)
self.assertEqual(model_fn_ops.loss, estimator_spec.loss)
self.assertEqual(model_fn_ops.train_op, estimator_spec.train_op)
self.assertEqual(model_fn_ops.eval_metric_ops,
self.assertEqual(expected_eval_metric_ops,
estimator_spec.eval_metric_ops)
self.assertEqual(model_fn_ops.training_chief_hooks,
estimator_spec.training_chief_hooks)
@@ -75,7 +83,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, None)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

def testEstimatorSpec_export_regression_with_scores(self):
predictions = self.create_predictions()
@@ -84,7 +92,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
@@ -103,7 +111,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
@@ -119,7 +127,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
classification_output = estimator_spec.export_outputs[
@@ -140,7 +148,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
classification_output = estimator_spec.export_outputs[
@@ -162,7 +170,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
classification_output = estimator_spec.export_outputs[
@@ -182,7 +190,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
classification_output = estimator_spec.export_outputs[
@@ -203,7 +211,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
classification_output = estimator_spec.export_outputs[
@@ -221,7 +229,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
logistic_output = estimator_spec.export_outputs["logistic_head"]
@@ -240,7 +248,7 @@ class ModelFnopsTest(test.TestCase):
model_fn_ops = self.create_model_fn_ops(predictions, output_alternatives)

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER)
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
unspecified_output = estimator_spec.export_outputs["unspecified_head"]
@@ -259,7 +267,7 @@ class ModelFnopsTest(test.TestCase):

estimator_spec = model_fn_ops.estimator_spec(model_fn.ModeKeys.INFER,
"regression_head")
self.assertEquals_except_export(model_fn_ops, estimator_spec)
self.assertEquals_except_export_and_eval_loss(model_fn_ops, estimator_spec)

with session.Session():
regression_output = estimator_spec.export_outputs["regression_head"]
@@ -18,8 +18,9 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from types import FunctionType, GeneratorType
from collections import Container
from types import FunctionType
from types import GeneratorType

from tensorflow.contrib.learn.python.learn.dataframe.queues import feeding_functions

@@ -33,7 +34,7 @@ def generator_input_fn(x,
num_threads=1):
"""Returns input function that would feed dicts of numpy arrays
yielded from a generator.

It is assumed that every dict yielded from the generator represents
a single sample. The generator should consume a single epoch of the data.

@@ -82,47 +83,44 @@ def generator_input_fn(x,
KeyError: `key` mismatch between dicts emitted from `x()`
"""
if not isinstance(x, FunctionType):
raise TypeError('x must be generator function; got {}'.format(
type(x).__name__))
raise TypeError(
'x must be generator function; got {}'.format(type(x).__name__))
generator = x()
if not isinstance(generator, GeneratorType):
raise TypeError('x() must be generator; got {}'.format(
type(generator).__name__))
raise TypeError(
'x() must be generator; got {}'.format(type(generator).__name__))
data = next(generator)
if not isinstance(data, dict):
raise TypeError('x() must yield dict; got {}'.format(
type(data).__name__))
raise TypeError('x() must yield dict; got {}'.format(type(data).__name__))
input_keys = sorted(next(x()).keys())
if target_key is not None:
if isinstance(target_key, str):
target_key = [target_key]
elif isinstance(target_key, Container):
elif isinstance(target_key, Container):
for item in target_key:
if not isinstance(item, str):
raise TypeError(
'target_key must be str or Container of str; got {}'.format(
type(item).__name__))
raise TypeError('target_key must be str or Container of str; got {}'.
format(type(item).__name__))
if item not in input_keys:
raise KeyError(
'target_key not in yielded dict. Expected {} keys; got {}'.format(
input_keys, item))
else:
raise TypeError(
'target_key must be str or Container of str; got {}'.format(
type(target_key).__name__))
raise TypeError('target_key must be str or Container of str; got {}'.
format(type(target_key).__name__))

def _generator_input_fn():
"""generator input function."""
queue = feeding_functions.enqueue_data(
x,
queue_capacity,
shuffle=shuffle,
num_threads=num_threads,
enqueue_size=batch_size,
num_epochs=num_epochs)
x,
queue_capacity,
shuffle=shuffle,
num_threads=num_threads,
enqueue_size=batch_size,
num_epochs=num_epochs)

features = (queue.dequeue_many(batch_size) if num_epochs is None
else queue.dequeue_up_to(batch_size))
features = (queue.dequeue_many(batch_size)
if num_epochs is None else queue.dequeue_up_to(batch_size))
if not isinstance(features, list):
features = [features]
features = dict(zip(input_keys, features))
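A usage sketch for `generator_input_fn`, inferred from the validation rules above: `x` is a zero-argument function returning a fresh generator of per-sample dicts, and `target_key` names the entries to split out as targets. The import path is the one the tests in this change use; treat the rest as illustrative:

```python
# Hedged usage sketch; mirrors the test cases rather than documenting an API.
import numpy as np
from tensorflow.contrib.learn.python.learn.learn_io import generator_io

def samples():
    for index in range(4):
        yield {'a': np.ones(1) * index,
               'b': np.ones(1) * index + 32,
               'label': np.ones(1) * index - 32}

input_fn = generator_io.generator_input_fn(
    samples, target_key='label', batch_size=2, shuffle=False, num_epochs=1)
# Inside an Estimator, input_fn() yields (features, target) tensor dicts fed
# from a background queue until the generator's single epoch is exhausted.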
@@ -35,17 +35,24 @@ from tensorflow.python.training import queue_runner_impl


class GeneratorIoTest(test.TestCase):

def testGeneratorInputFn(self):

def generator():
for index in range(2):
yield {'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32}
yield {
'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32
}

with self.test_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key='label', batch_size=2, shuffle=False, num_epochs=1)
generator,
target_key='label',
batch_size=2,
shuffle=False,
num_epochs=1)
features, target = input_fn()

coord = coordinator.Coordinator()
@@ -71,7 +78,7 @@ class GeneratorIoTest(test.TestCase):

with self.test_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()

coord = coordinator.Coordinator()
@@ -91,15 +98,20 @@ class GeneratorIoTest(test.TestCase):

def generator():
for index in range(2):
yield {'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32,
'label2': np.ones(1) * index - 64,
}
yield {
'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32,
'label2': np.ones(1) * index - 64,
}

with self.test_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=['label','label2'], batch_size=2, shuffle=False, num_epochs=1)
generator,
target_key=['label', 'label2'],
batch_size=2,
shuffle=False,
num_epochs=1)
features, target = input_fn()

coord = coordinator.Coordinator()
@@ -108,8 +120,10 @@ class GeneratorIoTest(test.TestCase):
res = session.run([features, target])
self.assertAllEqual(res[0]['a'], np.asarray([0, 1]).reshape(-1, 1))
self.assertAllEqual(res[0]['b'], np.asarray([32, 33]).reshape(-1, 1))
self.assertAllEqual(res[1]['label'], np.asarray([-32, -31]).reshape(-1, 1))
self.assertAllEqual(res[1]['label2'], np.asarray([-64, -63]).reshape(-1, 1))
self.assertAllEqual(res[1]['label'], np.asarray([-32, -31]).reshape(
-1, 1))
self.assertAllEqual(res[1]['label2'],
np.asarray([-64, -63]).reshape(-1, 1))

session.run([features])
with self.assertRaises(errors.OutOfRangeError):
@@ -122,22 +136,34 @@ class GeneratorIoTest(test.TestCase):

def generator():
for index in range(100):
yield {'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32}
yield {
'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32
}

with self.test_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key="label", batch_size=2, shuffle=False, num_epochs=1)
generator,
target_key='label',
batch_size=2,
shuffle=False,
num_epochs=1)
features, target = input_fn()

coord = coordinator.Coordinator()
threads = queue_runner_impl.start_queue_runners(session, coord=coord)

res = session.run([features, target])
self.assertAllEqual(res[0]['a'], np.vstack((np.zeros((10, 10)), np.ones((10, 10)))).reshape(2, 10, 10))
self.assertAllEqual(res[0]['b'], np.vstack((np.zeros((5, 5)), np.ones((5, 5)))).reshape(2, 5, 5) + 32)
self.assertAllEqual(res[1], np.vstack((np.zeros((3, 3)), np.ones((3, 3)))).reshape(2, 3, 3) - 32)
self.assertAllEqual(res[0]['a'],
np.vstack((np.zeros((10, 10)), np.ones(
(10, 10)))).reshape(2, 10, 10))
self.assertAllEqual(res[0]['b'],
np.vstack((np.zeros((5, 5)), np.ones(
(5, 5)))).reshape(2, 5, 5) + 32)
self.assertAllEqual(res[1],
np.vstack((np.zeros((3, 3)), np.ones(
(3, 3)))).reshape(2, 3, 3) - 32)

coord.request_stop()
coord.join(threads)
@@ -147,82 +173,97 @@ class GeneratorIoTest(test.TestCase):
with self.test_session():
with self.assertRaisesRegexp(TypeError, 'x must be generator function'):
failing_input_fn = generator_io.generator_input_fn(
x, batch_size=2, shuffle=False, num_epochs=1)
x, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()

def testGeneratorInputFnWithXAsNonGenerator(self):

def generator():
return np.arange(32, 36)

with self.test_session():
with self.assertRaisesRegexp(TypeError, "x\(\) must be generator"):
with self.assertRaisesRegexp(TypeError, 'x\(\) must be generator'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
generator, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()

def testGeneratorInputFnWithXAsNonGeneratorYieldingDicts(self):

def generator():
yield np.arange(32, 36)

with self.test_session():
with self.assertRaisesRegexp(TypeError, "x\(\) must yield dict"):
with self.assertRaisesRegexp(TypeError, 'x\(\) must yield dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, batch_size=2, shuffle=False, num_epochs=1)
generator, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()

def testGeneratorInputFNWithTargetLabelNotString(self):

def generator():
for index in range(2):
yield {'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32}
yield {
'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32
}

y = np.arange(32, 36)
with self.test_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()

def testGeneratorInputFNWithTargetLabelListNotString(self):

def generator():
for index in range(2):
yield {'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32}
yield {
'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32
}

y = ["label", np.arange(10)]
y = ['label', np.arange(10)]
with self.test_session():
with self.assertRaisesRegexp(TypeError, 'target_key must be str or'
' Container of str'):
' Container of str'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()

def testGeneratorInputFNWithTargetLabelNotInDict(self):

def generator():
for index in range(2):
yield {'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32}
yield {
'a': np.ones((10, 10)) * index,
'b': np.ones((5, 5)) * index + 32,
'label': np.ones((3, 3)) * index - 32
}

y = ["label", "target"]
y = ['label', 'target']
with self.test_session():
with self.assertRaisesRegexp(KeyError,
'target_key not in yielded dict'):
with self.assertRaisesRegexp(KeyError, 'target_key not in yielded dict'):
failing_input_fn = generator_io.generator_input_fn(
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
generator, target_key=y, batch_size=2, shuffle=False, num_epochs=1)
failing_input_fn()

def testGeneratorInputFnWithNoTargetKey(self):

def generator():
for index in range(2):
yield {'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32}
yield {
'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32
}

with self.test_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()

coord = coordinator.Coordinator()
@@ -241,15 +282,18 @@ class GeneratorIoTest(test.TestCase):
coord.join(threads)

def testGeneratorInputFnWithBatchLargerthanData(self):

def generator():
for index in range(2):
yield {'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32}
yield {
'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32
}

with self.test_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
generator, target_key=None, batch_size=4, shuffle=False, num_epochs=1)
features = input_fn()

coord = coordinator.Coordinator()
@@ -268,19 +312,24 @@ class GeneratorIoTest(test.TestCase):
coord.join(threads)

def testGeneratorInputFnWithMismatchinGeneratorKeys(self):

def generator():
index = 0
yield {'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32}
yield {
'a': np.ones(1) * index,
'b': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32
}
index = 1
yield {'a': np.ones(1) * index,
'c': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32}
yield {
'a': np.ones(1) * index,
'c': np.ones(1) * index + 32,
'label': np.ones(1) * index - 32
}

with self.test_session() as session:
input_fn = generator_io.generator_input_fn(
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
generator, target_key=None, batch_size=2, shuffle=False, num_epochs=1)
features = input_fn()

coord = coordinator.Coordinator()
@@ -290,9 +339,10 @@ class GeneratorIoTest(test.TestCase):
session.run([features])

with self.assertRaisesRegex(KeyError, 'key mismatch between dicts emitted'
' by GenFunExpected'):
' by GenFunExpected'):
coord.request_stop()
coord.join(threads)


if __name__ == '__main__':
test.main()
@@ -99,8 +99,13 @@ class ExternalOptimizerInterface(object):
slice(start, end) for start, end in zip(accumulated_dims[:-1],
accumulated_dims[1:])]

def minimize(self, session=None, feed_dict=None, fetches=None,
step_callback=None, loss_callback=None, **run_kwargs):
def minimize(self,
session=None,
feed_dict=None,
fetches=None,
step_callback=None,
loss_callback=None,
**run_kwargs):
"""Minimize a scalar `Tensor`.

Variables subject to optimization are updated in-place at the end of
@@ -120,7 +125,7 @@ class ExternalOptimizerInterface(object):
flattened into a single vector.
loss_callback: A function to be called every time the loss and gradients
are computed, with evaluated fetches supplied as positional arguments.
run_kwargs: kwargs to pass to `session.run`.
**run_kwargs: kwargs to pass to `session.run`.
"""
session = session or ops.get_default_session()
feed_dict = feed_dict or {}
@@ -161,9 +166,10 @@ class ExternalOptimizerInterface(object):
for packing_slice in self._packing_slices]

# Set optimization variables to their new values.
session.run(self._var_updates,
feed_dict=dict(zip(self._update_placeholders, var_vals)),
**run_kwargs)
session.run(
self._var_updates,
feed_dict=dict(zip(self._update_placeholders, var_vals)),
**run_kwargs)

def _minimize(self, initial_val, loss_grad_func, equality_funcs,
equality_grad_funcs, inequality_funcs, inequality_grad_funcs,
@@ -12,7 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Seq2seq loss operations for use in sequence models.
"""

@@ -28,16 +27,21 @@ from tensorflow.python.ops import nn_ops
__all__ = ["sequence_loss"]


def sequence_loss(logits, targets, weights,
average_across_timesteps=True, average_across_batch=True,
softmax_loss_function=None, name=None):
"""Weighted cross-entropy loss for a sequence of logits. Depending on the
values of `average_across_timesteps` and `average_across_batch`, the return
Tensor will have rank 0, 1, or 2 as these arguments reduce the cross-entropy
at each target, which has shape `[batch_size, sequence_length]`, over their
respective dimensions. For example, if `average_across_timesteps` is `True`
and `average_across_batch` is `False`, then the return Tensor will have shape
`[batch_size]`.
def sequence_loss(logits,
targets,
weights,
average_across_timesteps=True,
average_across_batch=True,
softmax_loss_function=None,
name=None):
"""Weighted cross-entropy loss for a sequence of logits.

Depending on the values of `average_across_timesteps` and
`average_across_batch`, the return Tensor will have rank 0, 1, or 2 as these
arguments reduce the cross-entropy at each target, which has shape
`[batch_size, sequence_length]`, over their respective dimensions. For
example, if `average_across_timesteps` is `True` and `average_across_batch`
is `False`, then the return Tensor will have shape `[batch_size]`.

Args:
logits: A Tensor of shape
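A numpy sketch of the reduction semantics the reworded docstring describes: the per-target cross-entropy has shape `[batch_size, sequence_length]`, and the two flags average it over one or both dimensions. This is only an illustration of the shapes, not the real op:

```python
# Shape semantics of sequence_loss's two averaging flags (numpy stand-in).
import numpy as np

batch_size, sequence_length = 4, 7
crossent = np.random.rand(batch_size, sequence_length)
weights = np.ones((batch_size, sequence_length))

def reduce_loss(crossent, weights, average_across_timesteps,
                average_across_batch):
    loss = crossent * weights
    if average_across_timesteps and average_across_batch:
        return loss.sum() / weights.sum()              # rank 0
    if average_across_timesteps:
        return loss.sum(axis=1) / weights.sum(axis=1)  # rank 1: [batch_size]
    if average_across_batch:
        return loss.sum(axis=0) / weights.sum(axis=0)  # rank 1: [seq_length]
    return loss                                        # rank 2

print(reduce_loss(crossent, weights, True, False).shape)  # (4,)
```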
@@ -274,14 +274,14 @@ Status BaseGPUDevice::FillContextMap(const Graph* graph,
DeviceContextMap* device_context_map) {
VLOG(2) << "FillContextMap";

const auto num_streams = streams_.size();
const size_t num_streams = streams_.size();
// Special case for single stream.
if (num_streams == 1) {
return Status::OK();
}
const int64 before = Env::Default()->NowMicros();
gpu_stream_util::AssignStreamsOpts opts;
opts.max_streams = num_streams;
opts.max_streams = static_cast<int32>(num_streams);
std::unordered_map<int, int> node_to_stream_id;
TF_RETURN_IF_ERROR(
gpu_stream_util::AssignStreams(graph, opts, &node_to_stream_id));
@@ -519,7 +519,7 @@ void BaseGPUDevice::ReinitializeGpuDevice(OpKernelContext* context,
Status BaseGPUDeviceFactory::CreateDevices(const SessionOptions& options,
const string& name_prefix,
std::vector<Device*>* devices) {
int n = INT_MAX;
size_t n = INT_MAX;
auto iter = options.config.device_count().find("GPU");
if (iter != options.config.device_count().end()) {
n = iter->second;
@@ -971,7 +971,7 @@ Status BaseGPUDeviceFactory::GetValidDeviceIds(
continue;
}

int new_id = ids->size();
size_t new_id = ids->size();
ids->push_back(visible_gpu_id);

LOG(INFO) << "Creating TensorFlow device (/gpu:" << new_id << ") -> "
@@ -37,12 +37,12 @@ class TEST_EventMgrHelper {
StopPollingLoop();
}

int queue_size() {
size_t queue_size() {
mutex_lock l(em_->mu_);
return em_->used_events_.size();
}

int free_size() {
size_t free_size() {
mutex_lock l(em_->mu_);
return em_->free_events_.size();
}
@@ -299,6 +299,13 @@ Status ShapeRefiner::ExtractConstantSubgraph(
return Status::OK();
}

// Don't constant fold enter/exit currently either, as it's easy to end
// up with a partial frame.
if (IsEnter(current_node) || IsExit(current_node)) {
*is_constant_graph = false;
return Status::OK();
}

// If there is nothing more to recurse down, see if
// the generator node is a constant.
if (current_node->num_inputs() == 0) {
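For orientation, Enter and Exit nodes are not written by hand; they are created by control-flow frames such as tf.while_loop. A hypothetical Python snippet whose graph contains them (with this change, the shape refiner's constant folding stops at the frame boundary instead of risking a partial frame):

    import tensorflow as tf

    # tf.while_loop wraps its loop variables in Enter/Exit nodes.
    limit = tf.constant(10)
    i0 = tf.constant(0)
    result = tf.while_loop(lambda i: i < limit, lambda i: i + 1, [i0])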
@@ -24,6 +24,7 @@ limitations under the License.
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/map_util.h"
#include "tensorflow/core/util/equal_graph_def.h"

namespace tensorflow {

@@ -652,6 +653,36 @@ string DebugStringWhole(const GraphDef& gdef) {
return ret;
}

bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2) {
// NOTE(skyewm): Using MessageDifferencer would be better here, but that is
// currently not included in tensorflow/core/platform/default/protobuf.h, so
// we play fast and loose here. I don't see anything in OpDef that should
// allow multiple equivalent string serializations, with the exception of
// AttrValues, which can vary for tensor values (see AreAttrValuesEqual()
// comments).
string sig1, sig2;
f1.signature().SerializeToString(&sig1);
f2.signature().SerializeToString(&sig2);
if (sig1 != sig2) return false;

if (f1.attr().size() != f2.attr().size()) return false;
for (auto iter1 : f1.attr()) {
auto iter2 = f2.attr().find(iter1.first);
if (iter2 == f2.attr().end()) return false;
if (!AreAttrValuesEqual(iter1.second, iter2->second)) return false;
}

if (!EqualRepeatedNodeDef(f1.node_def(), f2.node_def(), nullptr)) {
return false;
}

std::map<string, string> ret1(f1.ret().begin(), f1.ret().end());
std::map<string, string> ret2(f2.ret().begin(), f2.ret().end());
if (ret1 != ret2) return false;

return true;
}

string Canonicalize(const string& funcname,
const InstantiateAttrValueMap& attrs) {
std::vector<string> entries;
@@ -802,6 +833,17 @@ Status FunctionLibraryDefinition::AddLibrary(
return Status::OK();
}

Status FunctionLibraryDefinition::AddLibrary(
const FunctionDefLibrary& lib_def) {
for (const FunctionDef& fdef : lib_def.function()) {
TF_RETURN_IF_ERROR(AddFunctionDef(fdef));
}
for (const GradientDef& grad : lib_def.gradient()) {
TF_RETURN_IF_ERROR(AddGradientDef(grad));
}
return Status::OK();
}

string FunctionLibraryDefinition::FindGradient(const string& func) const {
return gtl::FindWithDefault(func_grad_, func, "");
}
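The same comparison recipe, restated against the protobuf Python bindings purely for illustration (a hedged sketch, not an exported API; sorting serialized NodeDefs only approximates the order-insensitive EqualRepeatedNodeDef above, and exact AttrValue equality is stricter than AreAttrValuesEqual for tensor values):

    def function_defs_equal(f1, f2):
        # f1, f2 are function_pb2.FunctionDef messages.
        # Signatures are compared via serialization, as in the C++ code.
        if (f1.signature.SerializeToString() !=
                f2.signature.SerializeToString()):
            return False
        # Protobuf map fields compare as dicts of messages.
        if dict(f1.attr) != dict(f2.attr):
            return False
        # NodeDef order doesn't matter, so compare sorted serializations.
        if (sorted(n.SerializeToString() for n in f1.node_def) !=
                sorted(n.SerializeToString() for n in f2.node_def)):
            return False
        return dict(f1.ret) == dict(f2.ret)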
@@ -230,6 +230,10 @@ string DebugString(const GraphDef& instantiated_func_def);
// its supporting functions defined in its library).
string DebugStringWhole(const GraphDef& gdef);

// Returns true if f1 == f2. Compares all fields, including descriptions. Order
// of NodeDefs doesn't matter.
bool FunctionDefsEqual(const FunctionDef& f1, const FunctionDef& f2);

// Returns a canonicalized string for the instantiation of the
// function of the given "name" and attributes "attrs".
//
@@ -303,6 +307,9 @@ class FunctionLibraryDefinition : public OpRegistryInterface {
// Adds the functions and gradients in 'other' to this function library.
Status AddLibrary(const FunctionLibraryDefinition& other);

// Adds the functions and gradients in 'lib_def' to this function library.
Status AddLibrary(const FunctionDefLibrary& lib_def);

// If the gradient function for 'func' is specified explicitly in
// the library, returns the gradient function name. Otherwise,
// returns an empty string.
|
||||
EXPECT_EQ(annotation, false); // WXPlusB has no custom gradient.
|
||||
}
|
||||
|
||||
// TODO(skyewm): this could be more thorough
|
||||
TEST(FunctionDefsEqualTest, TestFunctionDefsEqual) {
|
||||
// Equal functions
|
||||
FunctionDef fdef1 = test::function::XTimesTwo();
|
||||
FunctionDef fdef2 = test::function::XTimesTwo();
|
||||
EXPECT_TRUE(FunctionDefsEqual(fdef1, fdef2));
|
||||
|
||||
// Different functions
|
||||
fdef2 = test::function::XTimesFour();
|
||||
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
||||
|
||||
// Different signatures
|
||||
fdef2 = test::function::XTimesTwo();
|
||||
fdef2.mutable_signature()->mutable_input_arg(0)->set_name("foo");
|
||||
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
||||
|
||||
// Descriptions must be equal
|
||||
fdef2 = test::function::XTimesTwo();
|
||||
fdef2.mutable_signature()->mutable_input_arg(0)->set_description("foo");
|
||||
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
||||
|
||||
// Different NodeDefs
|
||||
fdef2 = test::function::XTimesTwo();
|
||||
*fdef2.add_node_def() = fdef2.node_def(0);
|
||||
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
||||
|
||||
// Different return values
|
||||
fdef2 = test::function::XTimesTwo();
|
||||
(*fdef2.mutable_ret())["y"] = "y:z:1"; // originally is "y:z:0"
|
||||
EXPECT_FALSE(FunctionDefsEqual(fdef1, fdef2));
|
||||
}
|
||||
|
||||
} // end namespace tensorflow
|
||||
|
@@ -360,6 +360,45 @@ void Graph::RemoveEdge(const Edge* e) {
free_edges_.push_back(del);
}

Status Graph::AddFunctionLibrary(const FunctionDefLibrary& fdef_lib) {
for (const FunctionDef& fdef : fdef_lib.function()) {
const FunctionDef* preexisting_fdef = ops_.Find(fdef.signature().name());
if (preexisting_fdef != nullptr) {
if (!FunctionDefsEqual(*preexisting_fdef, fdef)) {
return errors::InvalidArgument(
"Cannot add function '", fdef.signature().name(),
"' because a different function with the same name already "
"exists.");
}
// Ignore duplicate FunctionDefs
continue;
}
// TODO(skyewm): fix test breakages and reenable this check
// const OpDef* op_def;
// if (ops_.LookUpOpDef(fdef.signature().name(), &op_def).ok()) {
//   return errors::InvalidArgument(
//       "Cannot add function '", fdef.signature().name(),
//       "' because an op with the same name already exists.");
// }
TF_RETURN_IF_ERROR(ops_.AddFunctionDef(fdef));
}
for (const GradientDef& grad : fdef_lib.gradient()) {
string preexisting_grad_func = ops_.FindGradient(grad.function_name());
if (!preexisting_grad_func.empty()) {
if (preexisting_grad_func != grad.gradient_func()) {
return errors::InvalidArgument(
"Cannot assign gradient function '", grad.gradient_func(), "' to '",
grad.function_name(), "' because it already has gradient function ",
"'", preexisting_grad_func, "'");
}
// Ignore duplicate GradientDefs
continue;
}
TF_RETURN_IF_ERROR(ops_.AddGradientDef(grad));
}
return Status::OK();
}

namespace {

void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
@@ -380,7 +419,8 @@ void Graph::ToGraphDef(GraphDef* graph_def) const {

void Graph::ToGraphDefSubRange(GraphDef* graph_def, int from_node_id) const {
graph_def->Clear();
graph_def->mutable_versions()->CopyFrom(versions());
*graph_def->mutable_versions() = versions();
*graph_def->mutable_library() = ops_.ToProto();
std::vector<const Edge*>
inputs;  // Construct this outside the loop for speed.
for (auto id = from_node_id; id < num_node_ids(); ++id) {
@@ -324,6 +324,12 @@ class Graph {
// REQUIRES: The edge must exist.
void RemoveEdge(const Edge* edge);

// Adds the function and gradient definitions in `fdef_lib` to this graph's op
// registry. Ignores duplicate functions, and returns a bad status if an
// imported function differs from an existing function or op with the same
// name.
Status AddFunctionLibrary(const FunctionDefLibrary& fdef_lib);

// The number of live nodes in the graph.
//
// Because nodes can be removed from the graph, num_nodes() is often
@@ -22,6 +22,7 @@ limitations under the License.
#include <vector>

#include "tensorflow/core/common_runtime/shape_refiner.h"
#include "tensorflow/core/framework/function.h"
#include "tensorflow/core/framework/function.pb.h"
#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/node_def_util.h"
@@ -604,6 +605,10 @@ void GraphConstructor::AddPrefixToNodeDef(
}

Status GraphConstructor::Convert() {
// Import functions before adding nodes, since imported nodes may refer to
// functions
TF_RETURN_IF_ERROR(g_->AddFunctionLibrary(gdef_->library()));

std::vector<InputInfo> inputs;
int processed = 0;
// Process the NodeDefs in topological order.
@@ -705,7 +710,12 @@ Status GraphConstructor::Convert() {
TF_RETURN_IF_ERROR(MakeEdge(inputs[i].node, inputs[i].index, node, i));
}
}
TF_RETURN_IF_ERROR(ValidateShape(node));

// TODO(skyewm): remove conditional when b/35715995 ("Functions lack shape
// inference") is resolved.
if (g_->flib_def().Find(node_def->name()) == nullptr) {
TF_RETURN_IF_ERROR(ValidateShape(node));
}

// Update pending_count_ for outputs.
for (size_t i = 0; i < outputs_[o].size(); ++i) {
@@ -847,10 +857,6 @@ Status ImportGraphDef(const ImportGraphDefOptions& opts, const GraphDef& gdef,
return_tensors->size(), ")");
}
}
if (gdef.library().function_size() != 0) {
return errors::Unimplemented(
"Importing GraphDefs containing functions not yet implemented");
}
return GraphConstructor::Construct(opts, &gdef, g, refiner, return_tensors);
}

@@ -113,8 +113,6 @@ struct ImportGraphDefOptions {
// with ops that are not defined in the binary calling ImportGraphDef.
// Similar to the producer_op_list argument to import_graph_def in the
// python API.

// TODO(skyewm): Enable importing functions
};

// Each `return_tensors` entry is the requested node and output index. The index
|
||||
#include "tensorflow/core/platform/logging.h"
|
||||
#include "tensorflow/core/platform/protobuf.h"
|
||||
#include "tensorflow/core/platform/test.h"
|
||||
#include "tensorflow/core/public/session.h"
|
||||
#include "tensorflow/core/public/version.h"
|
||||
|
||||
// TODO(josh11b): Test InitCostModel().
|
||||
@ -2008,30 +2009,196 @@ TEST_F(GraphConstructorTest, ImportGraphDef_ErrorsDoNoChangeTheGraph) {
|
||||
#undef EXPECT_IMPORT_FAILURE
|
||||
}
|
||||
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_ErrorFunctionDefsUnimplemented) {
|
||||
ExpectError(
|
||||
TEST_F(GraphConstructorTest, ImportGraphDef_FunctionDefs) {
|
||||
// Import a graph def containing a function. The graph def was generated using
|
||||
// this python code:
|
||||
// @function.Defun(tf.float32, tf.float32, tf.float32)
|
||||
// def FooGrad(x, y, dz): return dz, dz
|
||||
//
|
||||
// @function.Defun(tf.float32, tf.float32, grad_func=FooGrad)
|
||||
// def Foo(x, y): return x + y
|
||||
//
|
||||
// p1 = tf.placeholder(tf.float32)
|
||||
// p2 = tf.placeholder(tf.float32)
|
||||
// foo = Foo(p1, p2)
|
||||
ImportGraphDefOptions opts;
|
||||
ExpectOK(
|
||||
R"EOF(
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "Foo_cc661786"
|
||||
input_arg {
|
||||
name: "x"
|
||||
type: DT_FLOAT
|
||||
node {
|
||||
name: "Placeholder" op: "Placeholder"
|
||||
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||
attr { key: "shape" value { shape { } } }
|
||||
}
|
||||
output_arg {
|
||||
name: "x"
|
||||
type: DT_FLOAT
|
||||
node {
|
||||
name: "Placeholder_1" op: "Placeholder"
|
||||
attr { key: "dtype" value { type: DT_FLOAT } }
|
||||
attr { key: "shape" value { shape { } } }
|
||||
}
|
||||
}
|
||||
ret {
|
||||
key: "x"
|
||||
value: "x:0"
|
||||
}
|
||||
}
|
||||
})EOF",
|
||||
ImportGraphDefOptions(),
|
||||
{"Importing GraphDefs containing functions not yet implemented"});
|
||||
node {
|
||||
name: "Foo_d03c39a3" op: "Foo_d03c39a3"
|
||||
input: "Placeholder" input: "Placeholder_1"
|
||||
}
|
||||
library {
|
||||
function {
|
||||
signature {
|
||||
name: "Foo_d03c39a3"
|
||||
input_arg { name: "x" type: DT_FLOAT }
|
||||
input_arg { name: "y" type: DT_FLOAT }
|
||||
output_arg { name: "add" type: DT_FLOAT }
|
||||
}
|
||||
node_def {
|
||||
name: "add" op: "Add" input: "x" input: "y"
|
||||
attr { key: "T" value { type: DT_FLOAT } }
|
||||
}
|
||||
ret { key: "add" value: "add:z:0" }
|
||||
}
|
||||
function {
|
||||
signature {
|
||||
name: "FooGrad_dc60abc8"
|
||||
input_arg { name: "x" type: DT_FLOAT }
|
||||
input_arg { name: "y" type: DT_FLOAT }
|
||||
input_arg { name: "dz" type: DT_FLOAT }
|
||||
output_arg { name: "dz" type: DT_FLOAT }
|
||||
output_arg { name: "dz_U0" type: DT_FLOAT }
|
||||
}
|
||||
ret { key: "dz" value: "dz:0" }
|
||||
ret { key: "dz_U0" value: "dz:0" }
|
||||
}
|
||||
gradient {
|
||||
function_name: "Foo_d03c39a3" gradient_func: "FooGrad_dc60abc8"
|
||||
}
|
||||
}
|
||||
versions { producer: 21 min_consumer: 12 }
|
||||
)EOF",
|
||||
opts);
|
||||
|
||||
EXPECT_TRUE(HasNode("Placeholder"));
|
||||
EXPECT_TRUE(HasNode("Placeholder_1"));
|
||||
EXPECT_TRUE(HasNode("Foo_d03c39a3"));
|
||||
// Check that Foo and FooGrad have been imported
|
||||
const OpDef* op_def;
|
||||
TF_ASSERT_OK(graph_.op_registry()->LookUpOpDef("Foo_d03c39a3", &op_def));
|
||||
TF_ASSERT_OK(graph_.op_registry()->LookUpOpDef("FooGrad_dc60abc8", &op_def));
|
||||
|
||||
// Re-serialize and run the graph. This tests that re-serialized functions can
|
||||
// be imported again and that imported functions can be run.
|
||||
GraphDef gdef;
|
||||
graph_.ToGraphDef(&gdef);
|
||||
EXPECT_EQ(gdef.library().function_size(), 2);
|
||||
EXPECT_EQ(gdef.library().gradient_size(), 1);
|
||||
EXPECT_EQ(gdef.library().gradient()[0].function_name(), "Foo_d03c39a3");
|
||||
EXPECT_EQ(gdef.library().gradient()[0].gradient_func(), "FooGrad_dc60abc8");
|
||||
|
||||
std::unique_ptr<Session> sess(NewSession(SessionOptions()));
|
||||
TF_ASSERT_OK(sess->Create(gdef));
|
||||
|
||||
Tensor p1(DT_FLOAT, TensorShape({1}));
|
||||
p1.scalar<float>()() = 1.0;
|
||||
Tensor p2(DT_FLOAT, TensorShape({1}));
|
||||
p2.scalar<float>()() = 2.0;
|
||||
std::vector<std::pair<string, Tensor>> inputs = {{"Placeholder", p1},
|
||||
{"Placeholder_1", p2}};
|
||||
std::vector<string> output_names = {"Foo_d03c39a3"};
|
||||
std::vector<string> target_names;
|
||||
std::vector<Tensor> outputs;
|
||||
TF_ASSERT_OK(sess->Run(inputs, output_names, target_names, &outputs));
|
||||
|
||||
ASSERT_EQ(outputs.size(), 1);
|
||||
EXPECT_EQ(outputs[0].scalar<float>()(), 3.0);
|
||||
}
|
||||
|
||||
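For reference, the kind of Python program the test comment above describes, which produces a GraphDef whose `library` field the C++ ImportGraphDef now consumes before adding nodes (a sketch; `function.Defun` is the internal API named in the comment, and the placeholder plumbing is illustrative):

    import tensorflow as tf
    from tensorflow.python.framework import function

    @function.Defun(tf.float32, tf.float32)
    def Foo(x, y):
        return x + y

    g = tf.Graph()
    with g.as_default():
        p1 = tf.placeholder(tf.float32)
        p2 = tf.placeholder(tf.float32)
        Foo(p1, p2)

    # gdef.library now holds the FunctionDef (and any GradientDefs) that
    # ImportGraphDef imports before processing the NodeDefs.
    gdef = g.as_graph_def()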
TEST_F(GraphConstructorTest, ImportGraphDef_NestedFunctionDefs) {
// Import a graph def containing a function. The graph def was generated using
// this python code:
// @function.Defun(tf.float32, tf.float32)
// def Inner(x, y): return x + y
//
// @function.Defun(tf.float32, tf.float32)
// def Outer(x, y): return Inner(x, y)
//
// p1 = tf.placeholder(tf.float32)
// p2 = tf.placeholder(tf.float32)
// Outer(p1, p2)
ImportGraphDefOptions opts;
ExpectOK(
R"EOF(
node {
name: "Placeholder" op: "Placeholder"
attr { key: "dtype" value { type: DT_FLOAT } }
attr { key: "shape" value { shape { } } }
}
node {
name: "Placeholder_1" op: "Placeholder"
attr { key: "dtype" value { type: DT_FLOAT } }
attr { key: "shape" value { shape { } } }
}
node {
name: "Outer_966fa13d" op: "Outer_966fa13d"
input: "Placeholder" input: "Placeholder_1"
}
library {
function {
signature {
name: "Outer_966fa13d"
input_arg { name: "x" type: DT_FLOAT }
input_arg { name: "y" type: DT_FLOAT }
output_arg { name: "Inner_d03c39a3" type: DT_FLOAT }
}
node_def {
name: "Inner_d03c39a3" op: "Inner_d03c39a3" input: "x" input: "y"
}
ret { key: "Inner_d03c39a3" value: "Inner_d03c39a3:add:0" }
}
function {
signature {
name: "Inner_d03c39a3"
input_arg { name: "x" type: DT_FLOAT }
input_arg { name: "y" type: DT_FLOAT }
output_arg { name: "add" type: DT_FLOAT }
}
node_def {
name: "add" op: "Add" input: "x" input: "y"
attr { key: "T" value { type: DT_FLOAT } }
}
ret { key: "add" value: "add:z:0" }
}
}
versions { producer: 21 min_consumer: 12 }
)EOF",
opts);

EXPECT_TRUE(HasNode("Placeholder"));
EXPECT_TRUE(HasNode("Placeholder_1"));
EXPECT_TRUE(HasNode("Outer_966fa13d"));
// Check that Inner and Outer have been imported
const OpDef* op_def;
Status s = graph_.op_registry()->LookUpOpDef("Inner_d03c39a3", &op_def);
ASSERT_TRUE(s.ok()) << s.error_message();
s = graph_.op_registry()->LookUpOpDef("Outer_966fa13d", &op_def);
ASSERT_TRUE(s.ok()) << s.error_message();

// Re-serialize and run the graph. This tests that re-serialized functions can
// be imported again and that imported functions can be run.
GraphDef gdef;
graph_.ToGraphDef(&gdef);
std::unique_ptr<Session> sess(NewSession(SessionOptions()));
s = sess->Create(gdef);
ASSERT_TRUE(s.ok()) << s.error_message();

Tensor p1(DT_FLOAT, TensorShape({1}));
p1.scalar<float>()() = 1.0;
Tensor p2(DT_FLOAT, TensorShape({1}));
p2.scalar<float>()() = 2.0;
std::vector<std::pair<string, Tensor>> inputs = {{"Placeholder", p1},
{"Placeholder_1", p2}};
std::vector<string> output_names = {"Outer_966fa13d"};
std::vector<string> target_names;
std::vector<Tensor> outputs;
s = sess->Run(inputs, output_names, target_names, &outputs);
ASSERT_TRUE(s.ok()) << s.error_message();

ASSERT_EQ(outputs.size(), 1);
EXPECT_EQ(outputs[0].scalar<float>()(), 3.0);
}

TEST_F(GraphConstructorTest, CopyGraph) {
@@ -17,6 +17,7 @@ limitations under the License.

#include <set>
#include <vector>
#include "tensorflow/core/framework/function_testlib.h"
#include "tensorflow/core/graph/graph_constructor.h"
#include "tensorflow/core/graph/node_builder.h"
#include "tensorflow/core/kernels/ops_util.h"
@@ -387,6 +388,61 @@ TEST_F(GraphTest, InputEdges) {
TF_EXPECT_OK(b->input_edges(&edges));
}

TEST_F(GraphTest, AddFunctionLibrary) {
// Basic functionality
FunctionDefLibrary proto;
*proto.add_function() = test::function::XTimesTwo();
*proto.add_function() = test::function::XTimesFour();
TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
EXPECT_TRUE(graph_.flib_def().Find("XTimesTwo") != nullptr);
EXPECT_TRUE(graph_.flib_def().Find("XTimesFour") != nullptr);

// Duplicate functions are ignored
TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
EXPECT_TRUE(graph_.flib_def().Find("XTimesTwo") != nullptr);
EXPECT_TRUE(graph_.flib_def().Find("XTimesFour") != nullptr);

// Duplicate names corresponding to different functions trigger an error
FunctionDefLibrary error_proto = proto;
*error_proto.mutable_function(0)->add_node_def() =
error_proto.function(0).node_def(0);
Status s = graph_.AddFunctionLibrary(error_proto);
EXPECT_FALSE(s.ok());
EXPECT_EQ(s.error_message(),
"Cannot add function 'XTimesTwo' because a different function with "
"the same name already exists.");

// TODO(skyewm): reenable along with duplicate op check
// Function with same name as an existing op triggers an error
// error_proto = proto;
// error_proto.mutable_function(0)->mutable_signature()->set_name("Add");
// s = graph_.AddFunctionLibrary(error_proto);
// EXPECT_FALSE(s.ok());
// EXPECT_EQ(s.error_message(),
//           "Cannot add function 'Add' because an op with the same name "
//           "already exists.");

// Adding a gradient function to an existing function is ok
GradientDef* grad = proto.add_gradient();
grad->set_function_name("XTimesTwo");
grad->set_gradient_func("Undefined");  // undefined funcs in grads are ok
TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

// Duplicate gradients are ignored
TF_EXPECT_OK(graph_.AddFunctionLibrary(proto));
EXPECT_EQ(graph_.flib_def().FindGradient("XTimesTwo"), "Undefined");

// Conflicting gradient triggers an error
error_proto = proto;
error_proto.mutable_gradient(0)->set_gradient_func("Undefined2");
s = graph_.AddFunctionLibrary(error_proto);
EXPECT_FALSE(s.ok());
EXPECT_EQ(s.error_message(),
"Cannot assign gradient function 'Undefined2' to 'XTimesTwo' "
"because it already has gradient function 'Undefined'");
}

REGISTER_OP("Input").Output("o: float");
REGISTER_OP("In2Out1").Input("a: float").Input("b: float").Output("o: float");

@@ -255,47 +255,47 @@ static size_t kNodeMergeContextMaxDepth = 10;
class MklLayoutRewritePass : public GraphOptimizationPass {
public:
MklLayoutRewritePass() {
csinfo_.conv2d = "Conv2D";
csinfo_.mklconv2d = "MklConv2D";
csinfo_.conv2d = "Conv2D";
csinfo_.mklconv2d = "MklConv2D";
csinfo_.mklconv2dwithbias = "MklConv2DWithBias";
csinfo_.mklconv2dwithbiasbackpropbias = "MklConv2DWithBiasBackpropBias";
csinfo_.biasadd = "BiasAdd";
csinfo_.matmul = "MatMul";
csinfo_.biasaddgrad = "BiasAddGrad";
csinfo_.relu = "Relu";
csinfo_.relugrad = "ReluGrad";
csinfo_.maxpool = "MaxPool";
csinfo_.maxpoolgrad = "MaxPoolGrad";
csinfo_.avgpool = "AvgPool";
csinfo_.avgpoolgrad = "AvgPoolGrad";
csinfo_.conv2dgradinput = "Conv2DBackpropInput";
csinfo_.conv2dgradfilter = "Conv2DBackpropFilter";
csinfo_.biasadd = "BiasAdd";
csinfo_.matmul = "MatMul";
csinfo_.biasaddgrad = "BiasAddGrad";
csinfo_.relu = "Relu";
csinfo_.relugrad = "ReluGrad";
csinfo_.maxpool = "MaxPool";
csinfo_.maxpoolgrad = "MaxPoolGrad";
csinfo_.avgpool = "AvgPool";
csinfo_.avgpoolgrad = "AvgPoolGrad";
csinfo_.conv2dgradinput = "Conv2DBackpropInput";
csinfo_.conv2dgradfilter = "Conv2DBackpropFilter";

rinfo_.push_back({csinfo_.conv2d, csinfo_.mklconv2d,
2, CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back(
{csinfo_.conv2d, csinfo_.mklconv2d, 2, CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2dgradfilter,
GetMklOpName(csinfo_.conv2dgradfilter),
3, CopyAttrsConv2D, AlwaysRewrite});
GetMklOpName(csinfo_.conv2dgradfilter), 3,
CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.conv2dgradinput,
GetMklOpName(csinfo_.conv2dgradinput),
3, CopyAttrsConv2D, AlwaysRewrite});
rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu),
1, CopyAttrsRelu, AlwaysRewrite});
rinfo_.push_back({csinfo_.maxpool, GetMklOpName(csinfo_.maxpool),
1, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.maxpoolgrad, GetMklOpName(csinfo_.maxpoolgrad),
3, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avgpool, GetMklOpName(csinfo_.avgpool),
1, CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avgpoolgrad, GetMklOpName(csinfo_.avgpoolgrad),
2, CopyAttrsPooling, AlwaysRewrite});
GetMklOpName(csinfo_.conv2dgradinput), 3, CopyAttrsConv2D,
AlwaysRewrite});
rinfo_.push_back({csinfo_.relu, GetMklOpName(csinfo_.relu), 1,
CopyAttrsRelu, AlwaysRewrite});
rinfo_.push_back({csinfo_.maxpool, GetMklOpName(csinfo_.maxpool), 1,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.maxpoolgrad, GetMklOpName(csinfo_.maxpoolgrad), 3,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avgpool, GetMklOpName(csinfo_.avgpool), 1,
CopyAttrsPooling, AlwaysRewrite});
rinfo_.push_back({csinfo_.avgpoolgrad, GetMklOpName(csinfo_.avgpoolgrad), 2,
CopyAttrsPooling, AlwaysRewrite});

// Add info about which ops to add workspace edge to and the slots.
wsinfo_.push_back({csinfo_.maxpool, csinfo_.maxpoolgrad, 0, 1, 2, 6});

// Add a rule for merging nodes
minfo_.push_back({csinfo_.mklconv2d, csinfo_.biasadd, 0,
csinfo_.mklconv2dwithbias});
minfo_.push_back(
{csinfo_.mklconv2d, csinfo_.biasadd, 0, csinfo_.mklconv2dwithbias});

// We use maxhop of 10 based on empirical observations. Also, these are
// maxhops in backward data-flow graph. Since input of forward nodes
@@ -322,13 +322,13 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
/// the number of inputs to the original op, and the function to be used
/// to copy attributes for the op
typedef struct {
string name;     // Original name of the op in the graph
string newname;  // New name of op in the graph
int numins;      // Number of inputs to the original op
string name;     // Original name of the op in the graph
string newname;  // New name of op in the graph
int numins;      // Number of inputs to the original op
// Function handler to copy attributes from old node to new node.
std::function<void(const Node*, NodeBuilder*)> copyattrs;
std::function<bool(const Node*)> rewriterule;  // Rule under which to
// rewrite this node.
// rewrite this node.
} RewriteInfo;

/// Structure to specify forward op, backward op, and the slot numbers
@@ -348,18 +348,18 @@ class MklLayoutRewritePass : public GraphOptimizationPass {

/// Structure to specify information used in node merge
typedef struct {
string pred;     // Predecessor node string
string succ;     // Successor node string
int op;          // What operand no the predecessor node corresponds
// to successor node?
string pred;     // Predecessor node string
string succ;     // Successor node string
int op;          // What operand no the predecessor node corresponds
// to successor node?
string newnode;  // Name of the node after merge
} MergeInfo;

/// Structure to specify the context information used in node rewrite rule
typedef struct {
string node;    // Name of the node to be rewritten
string fwd;     // Node name in forward pass that this node
// corresponds to
string node;    // Name of the node to be rewritten
string fwd;     // Node name in forward pass that this node
// corresponds to
size_t maxhop;  // Maximum number of hops the fwd is located
// from this node. If fwd is farther than maxhop
// then we do not rewrite the node.
@@ -418,9 +418,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
inline void MarkRewrittenNode(Node* n) { visited_nodes_.insert(n); }

// Clear all visited nodes
inline void UnMarkRewrittenNodes() {
visited_nodes_.clear();
}
inline void UnMarkRewrittenNodes() { visited_nodes_.clear(); }

// Get the name of Mkl op from original TensorFlow op
// We prefix 'Mkl' to the original op to get Mkl op.
@@ -455,7 +453,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// We check for 2 scenarios for rewrite.
//
// @return RewriteInfo* for the applicable rewrite rule
const RewriteInfo* CheckForNodeRewrite(const Node *n) const;
const RewriteInfo* CheckForNodeRewrite(const Node* n) const;

// Default rewrite rule to be used in scenario 1 for rewrite.
// @return - true (since we want to always rewrite)
@@ -512,7 +510,7 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
// NodeBuilder 'nb' for the new node provided. If 'orign' does not dictate
// adding workspace edge then do not add it.
void AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g, Node* orign,
NodeBuilder* nb);
NodeBuilder* nb);

// Functions specific to operators to copy attributes
// We need operator-specific function to copy attributes because the framework
@@ -528,10 +526,9 @@ class MklLayoutRewritePass : public GraphOptimizationPass {
void GetDummyMklTensorNode(std::unique_ptr<Graph>* g, Node** out,
Node* orign);
void GetDummyWorkspaceTensorNode(std::unique_ptr<Graph>* g, Node** out,
Node* orign);
Node* orign);
};


std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;

// We register Mkl rewrite pass for phase 1 in pre-placement group.
@@ -539,7 +536,6 @@ std::vector<MklLayoutRewritePass::ContextInfo> MklLayoutRewritePass::cinfo_;
REGISTER_OPTIMIZATION(OptimizationPassRegistry::PRE_PLACEMENT, 1,
MklLayoutRewritePass);


//////////////////////////////////////////////////////////////////////////
// Helper functions for creating new node
//////////////////////////////////////////////////////////////////////////
@@ -578,13 +574,14 @@ void MklLayoutRewritePass::GetDummyMklTensorNode(std::unique_ptr<Graph>* g,
8);
TensorShape dummy_shape({8});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orign->def().device())  // We place this node on same
// device as device of original
// node.
.Finalize(&**g, out));
TF_CHECK_OK(
NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orign->def().device())  // We place this node on same
// device as device of original
// node.
.Finalize(&**g, out));
(*out)->set_assigned_device_name(orign->assigned_device_name());
}

@@ -653,29 +650,30 @@ void MklLayoutRewritePass::GetDummyWorkspaceTensorNode(
TensorProto proto;
proto.set_dtype(dt);
float zero[1] = {0};
proto.set_tensor_content(const_cast<const void*>(
static_cast<void*>(&zero)), 4);
proto.set_tensor_content(const_cast<const void*>(static_cast<void*>(&zero)),
4);
TensorShape dummy_shape({1});
dummy_shape.AsProto(proto.mutable_tensor_shape());
TF_CHECK_OK(NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orign->def().device())  // We place this node on same
// device as device of original
// node.
.Finalize(&**g, out));
TF_CHECK_OK(
NodeBuilder((*g)->NewName("DMT"), "Const")
.Attr("value", proto)
.Attr("dtype", dt)
.Device(orign->def().device())  // We place this node on same
// device as device of original
// node.
.Finalize(&**g, out));
(*out)->set_assigned_device_name(orign->assigned_device_name());
}

void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
Node* orign, NodeBuilder* nb) {
Node* orign,
NodeBuilder* nb) {
bool workspace_edge_added = false;
DataType T;
TF_CHECK_OK(GetNodeAttr(orign->def(), "T", &T));
for (auto ws : wsinfo_) {
if (orign->type_string() == ws.fwdop &&
mkl_layer_registry::IsMklLayer(
GetMklOpName(orign->type_string()), T)) {
mkl_layer_registry::IsMklLayer(GetMklOpName(orign->type_string()), T)) {
// If this op is a fwd op, then we need to check if there is an
// edge from this node's fwdslot to bwdop's bwdslot. If there is
// an edge, then we just add an attribute on this node for setting
@@ -701,8 +699,8 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
nb->Attr("workspace_enabled", false);
}
} else if (orign->type_string() == ws.bwdop &&
mkl_layer_registry::IsMklLayer(
GetMklOpName(orign->type_string()), T)) {
mkl_layer_registry::IsMklLayer(
GetMklOpName(orign->type_string()), T)) {
// If this op is a bwd op, then we need to add workspace edge and
// its Mkl tensor edge between its corresponding fwd op and this
// op. Corresponding fwd op is specified in 'fwdop' field of
@@ -721,7 +719,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
// Add workspace edge between fwd op and bwd op.
nb->Input(e->src(), ws.wsfwdslot);
// Add Mkl tensor edge for workspace edge between fwd op and bwd op.
nb->Input(e->src(), ws.wsfwdslot+1);
nb->Input(e->src(), ws.wsfwdslot + 1);
// In terms of input ordering, we add these calls to add Input
// here because workspace edge (and its Mkl tensor) is the last
// edge in the fwdop and bwdop. So all inputs before workspace
@@ -740,17 +738,17 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
// workspace_enabled to false.
if (!workspace_edge_added) {
nb->Attr("workspace_enabled", false);
Node* dmt_ws = nullptr;  // Dummy tensor for workspace
Node* dmt_ws = nullptr;      // Dummy tensor for workspace
Node* dmt_mkl_ws = nullptr;  // Dummy Mkl tensor for workspace
GetDummyWorkspaceTensorNode(g, &dmt_ws, orign);
GetDummyMklTensorNode(g, &dmt_mkl_ws, orign);
CHECK_NOTNULL(dmt_ws);
CHECK_NOTNULL(dmt_mkl_ws);
nb->Input(dmt_ws, 0);  // We add dummy tensor as workspace tensor.
nb->Input(dmt_ws, 0);      // We add dummy tensor as workspace tensor.
nb->Input(dmt_mkl_ws, 0);  // We add dummy tensor as Mkl
// tensor for workspace tensor.
// tensor for workspace tensor.
VLOG(1) << "MklLayoutRewritePass: dummy workspace_enabled for "
<< orign->type_string();
<< orign->type_string();
}
} else {
// If this node does not match any workspace info, then we do not
@@ -763,8 +761,7 @@ void MklLayoutRewritePass::AddWorkSpaceEdgeIfNeeded(std::unique_ptr<Graph>* g,
// Op-specific functions to copy attributes from old node to new node
//////////////////////////////////////////////////////////////////////////

void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign,
NodeBuilder* nb) {
void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign, NodeBuilder* nb) {
DataType T;
string data_format;
string padding;
@@ -787,7 +784,7 @@ void MklLayoutRewritePass::CopyAttrsConv2D(const Node* orign,
}

void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign,
NodeBuilder* nb) {
NodeBuilder* nb) {
DataType T;
string data_format;
std::vector<int32> strides;
@@ -804,7 +801,7 @@ void MklLayoutRewritePass::CopyAttrsBiasAddGrad(const Node* orign,
}

void MklLayoutRewritePass::CopyAttrsPooling(const Node* orign,
NodeBuilder* nb) {
NodeBuilder* nb) {
DataType T;
string data_format;
string padding;
@@ -864,7 +861,7 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
FillInputs(a, &a_control_edges, &a_in);

// Get operand op of the operator
Node *b = nullptr;
Node* b = nullptr;
b = a_in[mi->op].first;
if (b == nullptr || (b->type_string() != mi->pred)) {
// NOTE: Should the first check be assert?
@@ -887,8 +884,8 @@ Node* MklLayoutRewritePass::CheckForNodeMerge(const Node* a) const {
return nullptr;
}

Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g,
Node* succ, Node* pred) {
Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g, Node* succ,
Node* pred) {
CHECK_NOTNULL(succ);
CHECK_NOTNULL(pred);

@@ -906,15 +903,14 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g,
TF_CHECK_OK(GetNodeAttr(pred->def(), "strides", &strides));
TF_CHECK_OK(GetNodeAttr(pred->def(), "data_format", &data_format_pred));
TF_CHECK_OK(GetNodeAttr(succ->def(), "data_format", &data_format_succ));
TF_CHECK_OK(GetNodeAttr(pred->def(), "use_cudnn_on_gpu",
&use_cudnn_on_gnu));
TF_CHECK_OK(
GetNodeAttr(pred->def(), "use_cudnn_on_gpu", &use_cudnn_on_gnu));
// We check to ensure that data formats of both succ and pred are same.
// We expect them to be same, so we can enforce this as assert.
// But assert can be too strict, so we enforce this as a check.
// If the check fails, then we do not merge two nodes.
// We also do same check for devices.
if (data_format_pred != data_format_succ ||
T_pred != T_succ ||
if (data_format_pred != data_format_succ || T_pred != T_succ ||
pred->assigned_device_name() != succ->assigned_device_name() ||
pred->def().device() != succ->def().device()) {
return Status(error::Code::INVALID_ARGUMENT,
@@ -940,11 +936,11 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g,
"Will skip node merge optimization");
}

for (const Edge *e : pred->out_edges()) {
for (const Edge* e : pred->out_edges()) {
if (e->dst() != succ) {
return Status(error::Code::INVALID_ARGUMENT,
"Conv2D does not feed to BiasAdd."
"Will skip node merge optimization");
"Conv2D does not feed to BiasAdd."
"Will skip node merge optimization");
}
}

@@ -955,8 +951,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g,
// Get operand 1 of add_bias
// BiasAdd must have 2 inputs: Conv, bias
CHECK_EQ(succ->in_edges().size(), 2);
Node* oper3_mkl = nullptr;  // Mkl tensor corresponding to oper3
int oper3_mkl_slot = 0;  // For dummy MKL tensor node, output slot is 0.
Node* oper3_mkl = nullptr;  // Mkl tensor corresponding to oper3
int oper3_mkl_slot = 0;     // For dummy MKL tensor node, output slot is 0.
GetDummyMklTensorNode(g, &oper3_mkl, succ);  // Get dummy Mkl tensor node
// as BiasAdd does not have Mkl tensor as input.
CHECK_NOTNULL(oper3_mkl);
@@ -997,8 +993,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g,
newn->set_assigned_device_name(pred->assigned_device_name());

VLOG(1) << "MklLayoutRewritePass: Merged old node:" << pred->DebugString()
<< ", and node: " << succ->DebugString() << ", into node:"
<< newn->DebugString();
<< ", and node: " << succ->DebugString()
<< ", into node:" << newn->DebugString();

(*g)->RemoveNode(succ);
(*g)->RemoveNode(pred);
@@ -1015,8 +1011,8 @@ Status MklLayoutRewritePass::MergeNode(std::unique_ptr<Graph>* g,
// Helper functions for node rewrite
//////////////////////////////////////////////////////////////////////////

Status MklLayoutRewritePass::RewriteNode(
std::unique_ptr<Graph>* g, Node* orign, const RewriteInfo* ri) {
Status MklLayoutRewritePass::RewriteNode(std::unique_ptr<Graph>* g, Node* orign,
const RewriteInfo* ri) {
CHECK_NOTNULL(ri);
CHECK_NOTNULL(orign);

@@ -1044,9 +1040,10 @@ Status MklLayoutRewritePass::RewriteNode(
if (orig_data_format != ctx_data_format || orig_T != ctx_T ||
orign->assigned_device_name() != fwdn->assigned_device_name() ||
orign->def().device() != fwdn->def().device()) {
return Status(error::Code::INVALID_ARGUMENT,
"data_format or T attribute or devices of BiasAddGrad and "
"Conv2D do not match. Will skip node rewrite optimization");
return Status(
error::Code::INVALID_ARGUMENT,
"data_format or T attribute or devices of BiasAddGrad and "
"Conv2D do not match. Will skip node rewrite optimization");
}
}
}
@@ -1077,7 +1074,7 @@ Status MklLayoutRewritePass::RewriteNode(
ri->copyattrs(fwdn, &nb);
} else {
return Status(error::Code::UNIMPLEMENTED,
"Unimplemented case for node rewrite optimization.");
"Unimplemented case for node rewrite optimization.");
}
} else {
ri->copyattrs(const_cast<const Node*>(orign), &nb);
@@ -1106,8 +1103,8 @@ Status MklLayoutRewritePass::RewriteNode(
if (e->src_output() < 0) {
(*g)->AddEdge(newn, e->src_output(), e->dst(), e->dst_input());
} else {
(*g)->AddEdge(newn, GetTensorDataIndex(e->src_output()),
e->dst(), e->dst_input());
(*g)->AddEdge(newn, GetTensorDataIndex(e->src_output()), e->dst(),
e->dst_input());
}
}

@@ -1123,8 +1120,7 @@ Status MklLayoutRewritePass::RewriteNode(
}

const MklLayoutRewritePass::ContextInfo*
MklLayoutRewritePass::SearchMatchingContext(const Node* n,
const Node** fwdn) {
MklLayoutRewritePass::SearchMatchingContext(const Node* n, const Node** fwdn) {
CHECK_NOTNULL(n);
CHECK_NOTNULL(fwdn);
*fwdn = nullptr;
@@ -1144,8 +1140,8 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
return nullptr;
}

VLOG(1) << "MklLayoutRewritePass: Searching graph for: "
<< n->type_string() << " in backwards.";
VLOG(1) << "MklLayoutRewritePass: Searching graph for: " << n->type_string()
<< " in backwards.";

// Now we will check for forward op name for context info in data
// flow graph. Get the max hops we should search for the fwd node.
@@ -1164,13 +1160,12 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
nqueue.pop();

std::set<const Node*> visited_nodes;
curr_node = curr_pair.first;
curr_node = curr_pair.first;
curr_depth = curr_pair.second;
CHECK_NOTNULL(curr_node);

VLOG(1) << "MklLayoutRewritePass: Visiting node: "
<< curr_node->type_string()
<< " at depth: " << curr_depth
<< curr_node->type_string() << " at depth: " << curr_depth
<< " for node: " << n->type_string();

// If we find a match, we return immediately.
@@ -1186,9 +1181,9 @@ MklLayoutRewritePass::SearchMatchingContext(const Node* n,
for (const Edge* e : curr_node->in_edges()) {
// We do not visit already visited node.
if (visited_nodes.find(e->src()) == visited_nodes.end()) {
// Depth of these nodes is 1 more than the depth of current node.
nqueue.push(std::make_pair(e->src(), curr_depth+1));
visited_nodes.insert(e->src());
// Depth of these nodes is 1 more than the depth of current node.
nqueue.push(std::make_pair(e->src(), curr_depth + 1));
visited_nodes.insert(e->src());
}
}
} /* while */
@@ -1202,8 +1197,7 @@ bool MklLayoutRewritePass::ContextMatchRewrite(const Node* n) {
}

const MklLayoutRewritePass::RewriteInfo*
MklLayoutRewritePass::CheckForNodeRewrite(
const Node *n) const {
MklLayoutRewritePass::CheckForNodeRewrite(const Node* n) const {
CHECK_NOTNULL(n);

// First check if node along with its type is supported by MKL layer.
@@ -1238,8 +1232,7 @@ MklLayoutRewritePass::CheckForNodeRewrite(
// Run function for the pass
///////////////////////////////////////////////////////////////////////////////

bool MklLayoutRewritePass::RunPass(
std::unique_ptr<Graph>* g) {
bool MklLayoutRewritePass::RunPass(std::unique_ptr<Graph>* g) {
bool result = false;
CHECK_NOTNULL(g);

@@ -1265,22 +1258,21 @@ bool MklLayoutRewritePass::RunPass(
<< " layout optimization.";

if (RewriteNode(g, n, ri) == Status::OK()) {
VLOG(1) << "MklLayoutRewritePass: rewrote node "
<< node_name << " with op " << op_name
<< " for Mkl layout optimization.";
result = true;
VLOG(1) << "MklLayoutRewritePass: rewrote node " << node_name
<< " with op " << op_name << " for Mkl layout optimization.";
result = true;
}
} else if ((predn = CheckForNodeMerge(n)) != nullptr) {
// Otherwise, we will check if the node is to be merged.
string n1_name = n->name();
string n2_name = predn->name();

VLOG(1) << "MklLayoutRewritePass: Scheduled nodes "
<< n1_name << " and " << n2_name << " for merging";
VLOG(1) << "MklLayoutRewritePass: Scheduled nodes " << n1_name << " and "
<< n2_name << " for merging";

if (MergeNode(g, n, predn) == Status::OK()) {
VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name
<< " and " << n2_name;
VLOG(1) << "MklLayoutRewritePass: Merged nodes " << n1_name << " and "
<< n2_name;
result = true;
}
}
|
||||
#include "tensorflow/core/graph/mkl_layout_pass.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
|
||||
#include <vector>
|
||||
#include <string>
|
||||
#include <algorithm>
|
||||
#include <string>
|
||||
#include <vector>
|
||||
|
||||
#include "tensorflow/core/framework/op.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
@ -112,8 +112,7 @@ class MklLayoutPassTest : public ::testing::Test {
|
||||
REGISTER_OP("Input").Output("o: float").SetIsStateful();
|
||||
REGISTER_OP("HalfInput").Output("o: half").SetIsStateful();
|
||||
REGISTER_OP("MklInput").Output("o: uint8").SetIsStateful();
|
||||
REGISTER_OP("MklInput2").Output("o: uint8")
|
||||
.Output("o1: uint8").SetIsStateful();
|
||||
REGISTER_OP("MklInput2").Output("o: uint8").Output("o1: uint8").SetIsStateful();
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
// Unit tests related to node merge optiimization
|
||||
@ -240,7 +239,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DWithBias_Negative_NoAddBias) {
|
||||
" input: ['A', 'M', 'B', 'N']}");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MklConv2D);M(MklInput);N(MklInput)|"
|
||||
"A->C;B->C:2;M->C:1;N->C:3");
|
||||
"A->C;B->C:2;M->C:1;N->C:3");
|
||||
}
|
||||
|
||||
// MklConv2D output does not go to BiasAdd.
|
||||
@ -372,7 +371,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D) {
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
// No Conv2D in the context for BiasAddGrad, but MatMul in context.
|
||||
@ -396,7 +395,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_Conv2DBackprop_Negative_NoConv2D_MatMul) {
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
// Test set 3: MatMul..BiasAddGrad -> BiasAddGrad rewrite tests
|
||||
@ -419,7 +418,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Positive) {
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(MatMul);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
// No MatMul in the context for BiasAddGrad. No rewrite should happen.
|
||||
@ -440,7 +439,7 @@ TEST_F(MklLayoutPassTest, NodeMerge_MatMulBiasAddGrad_Negative_NoMatMul) {
|
||||
" input: ['D'] }");
|
||||
EXPECT_EQ(DoMklLayoutOptimizationPass(),
|
||||
"A(Input);B(Input);C(Add);D(Sub);E(BiasAddGrad)|"
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
"A->C;A->D:1;B->C:1;C->D;D->E");
|
||||
}
|
||||
|
||||
/////////////////////////////////////////////////////////////////////
|
||||
|
@@ -212,7 +212,7 @@ bool MklToTfConversionPass::RunPass(std::unique_ptr<Graph>* g) {
// Check if src is Mkl-compliant, while dst is not Mkl-compliant.

if (IsMklSupportedOp(src->type_string(), src_datatype) &&
!IsMklSupportedOp(dst->type_string(), dst_datatype)) {
!IsMklSupportedOp(dst->type_string(), dst_datatype)) {
VLOG(1) << "MklToTfConversionPass: Scheduled nodes " << src->name()
<< " and " << dst->name() << " for inserting conversion nodes";
candidate_edges.push_back(const_cast<Edge*>(e));
@@ -17,9 +17,9 @@ limitations under the License.

#include "tensorflow/core/graph/mkl_tfconversion_pass.h"

#include <vector>
#include <string>
#include <algorithm>
#include <string>
#include <vector>

#include "tensorflow/core/framework/op.h"
#include "tensorflow/core/framework/tensor.h"
@@ -17,6 +17,7 @@ filegroup(
srcs = glob(
[
"*_optimizer.*",
"constant_folding.*",
"model_pruner.*",
"graph_rewriter.*",
],
@@ -117,6 +118,37 @@ cc_test(
],
)

cc_library(
name = "memory_optimizer",
srcs = ["memory_optimizer.cc"],
hdrs = [
"memory_optimizer.h",
],
visibility = ["//visibility:public"],
deps = [
":graph_optimizer",
":graph_rewriter",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
],
)

cc_test(
name = "memory_optimizer_test",
srcs = ["memory_optimizer_test.cc"],
deps = [
":memory_optimizer",
"//tensorflow/cc:cc_ops",
"//tensorflow/core:protos_all_cc",
"//tensorflow/core:test",
"//tensorflow/core:test_main",
"//tensorflow/core/grappler:grappler_item",
"//tensorflow/core/grappler:utils",
"//tensorflow/core/grappler/inputs:trivial_test_graph_input_yielder",
],
)

cc_library(
name = "layout_optimizer",
srcs = ["layout_optimizer.cc"],
@@ -144,6 +176,7 @@ cc_library(
],
visibility = ["//visibility:public"],
deps = [
":constant_folding",
":graph_optimizer",
":layout_optimizer",
":model_pruner",
@@ -72,8 +72,7 @@ class DeviceSimple : public DeviceBase {
Tensor* tensor) override {
Tensor parsed(tensor_proto.dtype());
if (!parsed.FromProto(cpu_allocator(), tensor_proto)) {
return errors::InvalidArgument("Cannot parse tensor from proto: ",
tensor_proto.DebugString());
return errors::InvalidArgument("Cannot parse tensor from tensor_proto.");
}
*tensor = parsed;
return Status::OK();
tensorflow/core/grappler/optimizers/memory_optimizer.cc (new file)
@@ -0,0 +1,83 @@
|
||||
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
|
||||
#include <unordered_set>
|
||||
#include "tensorflow/core/framework/attr_value.pb.h"
|
||||
#include "tensorflow/core/framework/node_def.pb.h"
|
||||
#include "tensorflow/core/grappler/grappler_item.h"
|
||||
#include "tensorflow/core/grappler/optimizers/graph_rewriter.h"
|
||||
#include "tensorflow/core/grappler/utils.h"
|
||||
|
||||
namespace tensorflow {
|
||||
namespace grappler {
|
||||
|
||||
std::pair<NodeDef*, NodeDef*> BuildSwapPair(NodeDef* node, int input_to_swap,
|
||||
GraphDef* graph) {
|
||||
string tensor_to_swap = strings::StrCat(node->name(), "_", input_to_swap);
|
||||
|
||||
// Force the tensor to be copied to cpu.
|
||||
NodeDef* swap_out_node = graph->add_node();
|
||||
swap_out_node->set_name(strings::StrCat("swap_out_", tensor_to_swap));
|
||||
swap_out_node->set_op("Identity");
|
||||
swap_out_node->set_device("/CPU");
|
||||
|
||||
// Force the tensor to be restored to the device.
|
||||
NodeDef* swap_in_node = graph->add_node();
|
||||
swap_in_node->set_name(strings::StrCat("swap_in_", tensor_to_swap));
|
||||
swap_in_node->set_op("Identity");
|
||||
*swap_in_node->add_input() = swap_out_node->name();
|
||||
|
||||
// Colocate the swap_in_ node with the node itself.
|
||||
string coloc_group = strings::StrCat("loc@", tensor_to_swap);
|
||||
(*swap_in_node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
|
||||
(*node->mutable_attr())["_class"].mutable_list()->add_s(coloc_group);
|
||||
|
||||
return std::make_pair(swap_out_node, swap_in_node);
|
||||
}
|
||||
|
||||
Status MemoryOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
|
||||
GraphDef* optimized_graph) {
|
||||
*optimized_graph = item.graph;
|
||||
|
||||
for (auto& node : *optimized_graph->mutable_node()) {
|
||||
if (node.attr().count("swap_to_host") == 0) {
|
||||
continue;
|
||||
}
|
||||
|
||||
// Swap all the tensors that are marked with the 'swap_to_host' attribute.
|
||||
for (int input_id : node.attr().at("swap_to_host").list().i()) {
|
||||
std::pair<NodeDef*, NodeDef*> swap_nodes =
|
||||
BuildSwapPair(&node, input_id, optimized_graph);
|
||||
*swap_nodes.first->add_input() = node.input(input_id);
|
||||
*node.mutable_input(input_id) = swap_nodes.second->name();
|
||||
|
||||
// TODO(bsteiner): Make sure the tensor isn't swapped back in right away
|
||||
// by adding a control dependency to delay the execution of the swap.
|
||||
// string trigger;
|
||||
//*swap_nodes.second->add_input() = strings::StrCat("^", trigger);
|
||||
}
|
||||
}
|
||||
|
||||
return Status::OK();
|
||||
}
|
||||
|
||||
void MemoryOptimizer::Feedback(Cluster* cluster, const GrapplerItem& item,
|
||||
const GraphDef& optimized_graph, double result) {
|
||||
// Nothing to do for MemoryOptimizer.
|
||||
}
|
||||
|
||||
} // end namespace grappler
|
||||
} // end namespace tensorflow
|
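As the loop in Optimize() above shows, a node opts into swapping through a `swap_to_host` attribute listing the input positions to spill to host memory. A minimal sketch of tagging a node, with a hypothetical helper name (the new unit test further below exercises the same path):

    // Mark input `input_id` of `node` so MemoryOptimizer reroutes it through a
    // swap_out_*/swap_in_* Identity pair pinned to the CPU.
    void MarkInputForSwap(tensorflow::NodeDef* node, int input_id) {
      tensorflow::AttrValue& val = (*node->mutable_attr())["swap_to_host"];
      val.mutable_list()->add_i(input_id);  // list of input positions to swap
    }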
tensorflow/core/grappler/optimizers/memory_optimizer.h (new file, 42 lines)
@@ -0,0 +1,42 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
#define TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_

#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"

namespace tensorflow {
namespace grappler {

// Swap tensors in and out of device memory.
class MemoryOptimizer : public GraphOptimizer {
 public:
  MemoryOptimizer() {}
  ~MemoryOptimizer() override {}

  string name() const override { return "memory_optimizer"; };

  Status Optimize(Cluster* cluster, const GrapplerItem& item,
                  GraphDef* pruned_graph) override;

  void Feedback(Cluster* cluster, const GrapplerItem& item,
                const GraphDef& pruned_graph, double result) override;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_GRAPPLER_OPTIMIZERS_MEMORY_OPTIMIZER_H_
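For orientation, the class above fills in the GraphOptimizer interface from graph_optimizer.h. A pass-through optimizer satisfying the same contract might look like the following sketch (a hypothetical class, shown only to illustrate the interface):

    class PassThroughOptimizer : public GraphOptimizer {
     public:
      string name() const override { return "pass_through"; }

      Status Optimize(Cluster* cluster, const GrapplerItem& item,
                      GraphDef* output) override {
        *output = item.graph;  // return the input graph unchanged
        return Status::OK();
      }

      void Feedback(Cluster* cluster, const GrapplerItem& item,
                    const GraphDef& output, double result) override {}
    };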
tensorflow/core/grappler/optimizers/memory_optimizer_test.cc (new file, 74 lines)
@@ -0,0 +1,74 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/optimizers/memory_optimizer.h"
#include "tensorflow/cc/ops/standard_ops.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/grappler/utils.h"
#include "tensorflow/core/lib/core/status_test_util.h"
#include "tensorflow/core/platform/test.h"

namespace tensorflow {
namespace grappler {
namespace {

class MemoryOptimizerTest : public ::testing::Test {};

TEST_F(MemoryOptimizerTest, SimpleSwapping) {
  // Build a simple graph with an op that's marked for swapping.
  tensorflow::Scope s = tensorflow::Scope::NewRootScope();

  Output a = ops::Const(s.WithOpName("a"), 0.0f, {10, 10});
  Output b = ops::AddN(s.WithOpName("b"), {a});
  Output c = ops::AddN(s.WithOpName("c"), {b});
  Output d = ops::AddN(s.WithOpName("d"), {c});
  Output e = ops::AddN(s.WithOpName("e"), {b, d});

  GrapplerItem item;
  TF_CHECK_OK(s.ToGraphDef(&item.graph));

  EXPECT_EQ(5, item.graph.node_size());
  EXPECT_EQ(NodeName(e.name()), item.graph.node(4).name());
  AttrValue& val =
      (*item.graph.mutable_node(4)->mutable_attr())["swap_to_host"];
  val.mutable_list()->add_i(0);

  MemoryOptimizer optimizer;
  GraphDef output;
  Status status = optimizer.Optimize(nullptr, item, &output);
  TF_EXPECT_OK(status);

  EXPECT_EQ(7, output.node_size());
  const NodeDef& new_e = output.node(4);
  EXPECT_EQ(NodeName(e.name()), new_e.name());

  EXPECT_EQ(2, new_e.input_size());
  EXPECT_EQ(NodeName(d.name()), new_e.input(1));
  EXPECT_EQ("swap_in_e_0", new_e.input(0));

  const NodeDef& swap_out = output.node(5);
  EXPECT_EQ("swap_out_e_0", swap_out.name());

  const NodeDef& swap_in = output.node(6);
  EXPECT_EQ("swap_in_e_0", swap_in.name());

  EXPECT_EQ(NodeName(b.name()), swap_out.input(0));
  EXPECT_EQ(NodeName(swap_out.name()), swap_in.input(0));
}

}  // namespace
}  // namespace grappler
}  // namespace tensorflow
@@ -14,6 +14,9 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/grappler/optimizers/meta_optimizer.h"
+#include "tensorflow/core/framework/versions.pb.h"
+#include "tensorflow/core/grappler/optimizers/constant_folding.h"
+#include "tensorflow/core/grappler/optimizers/graph_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/layout_optimizer.h"
 #include "tensorflow/core/grappler/optimizers/model_pruner.h"
 #include "tensorflow/core/lib/core/status.h"
@@ -21,25 +24,67 @@ limitations under the License.
 namespace tensorflow {
 namespace grappler {
 
+std::unique_ptr<GraphOptimizer> MetaOptimizer::NewOptimizer(
+    const string& optimizer) {
+  VLOG(1) << "Adding graph optimization pass: " << optimizer;
+  std::unique_ptr<GraphOptimizer> graph_optimizer;
+  if (optimizer == "pruning") {
+    graph_optimizer.reset(new ModelPruner());
+  }
+  if (optimizer == "constfold") {
+    graph_optimizer.reset(new ConstantFolding());
+  }
+  if (optimizer == "layout") {
+    graph_optimizer.reset(new LayoutOptimizer());
+  }
+  return graph_optimizer;
+}
+
 Status MetaOptimizer::Optimize(Cluster* cluster, const GrapplerItem& item,
                                GraphDef* optimized_graph) {
-  bool already_optimized = false;
-  if (!cfg_.disable_model_pruning()) {
-    already_optimized = true;
-    ModelPruner pruner;
-    TF_RETURN_IF_ERROR(pruner.Optimize(nullptr, item, optimized_graph));
+  std::vector<std::unique_ptr<GraphOptimizer>> optimizers;
+  if (cfg_.optimizers().empty()) {
+    if (!cfg_.disable_model_pruning()) {
+      optimizers.push_back(std::unique_ptr<GraphOptimizer>(new ModelPruner()));
+    }
+    if (cfg_.constant_folding()) {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new ConstantFolding()));
+    }
+    if (cfg_.optimize_tensor_layout()) {
+      optimizers.push_back(
+          std::unique_ptr<GraphOptimizer>(new LayoutOptimizer()));
+    }
+  } else {
+    std::set<string> available_optimizers = {"pruning", "constfold", "layout"};
+    for (const auto& optimizer : cfg_.optimizers()) {
+      if (available_optimizers.find(optimizer) != available_optimizers.end()) {
+        optimizers.push_back(NewOptimizer(optimizer));
+      }
+    }
   }
-  if (cfg_.optimize_tensor_layout()) {
-    LayoutOptimizer layout_optimizer;
+
+  if (optimizers.empty()) {
+    *optimized_graph = item.graph;
+    return Status::OK();
+  }
+
+  bool already_optimized = false;
+  for (const auto& optimizer : optimizers) {
     if (!already_optimized) {
-      return layout_optimizer.Optimize(nullptr, item, optimized_graph);
+      TF_RETURN_IF_ERROR(optimizer->Optimize(nullptr, item, optimized_graph));
+      already_optimized = true;
     } else {
       GrapplerItem optimized_item = item;
       optimized_item.graph = *optimized_graph;
-      return layout_optimizer.Optimize(nullptr, optimized_item,
-                                       optimized_graph);
+      TF_RETURN_IF_ERROR(
+          optimizer->Optimize(nullptr, optimized_item, optimized_graph));
     }
   }
+
+  // Copy the graph version.
+  *optimized_graph->mutable_versions() = item.graph.versions();
+
   return Status::OK();
 }
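The rewritten Optimize() above makes the pass order configurable: when RewriterConfig.optimizers is non-empty, the named passes run in the order given; otherwise the per-pass flags select a default pipeline. A minimal sketch, given a GrapplerItem `item` and assuming MetaOptimizer is constructed from a RewriterConfig as its header below suggests:

    RewriterConfig cfg;
    cfg.add_optimizers("constfold");  // fold constants first
    cfg.add_optimizers("pruning");    // then prune the folded graph

    MetaOptimizer meta(cfg);  // constructor taking the config is assumed
    GraphDef optimized;
    Status s = meta.Optimize(/*cluster=*/nullptr, item, &optimized);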
@@ -39,6 +39,7 @@ class MetaOptimizer : public GraphOptimizer {
                 const GraphDef& optimized_graph, double result) override;
 
  private:
+  std::unique_ptr<GraphOptimizer> NewOptimizer(const string& optimizer);
   RewriterConfig cfg_;
 };
 
@@ -99,25 +99,26 @@ struct LaunchXsmmBackwardFilter {
                   typename TTypes<T, 4>::Tensor kernel,
                   typename TTypes<T, 4>::ConstTensor output_backward,
                   int input_rows, int input_cols, int row_stride,
-                  int col_stride, int pad_h, int pad_w, TensorFormat data_format) const {
+                  int col_stride, int pad_h, int pad_w,
+                  TensorFormat data_format) const {
     return false;
   }
 };
 
 template <>
 struct LaunchXsmmBackwardFilter<CPUDevice, float> {
   bool operator()(OpKernelContext* context, const CPUDevice& d,
                   typename TTypes<float, 4>::ConstTensor input,
                   typename TTypes<float, 4>::Tensor filter,
-                  typename TTypes<float, 4>::ConstTensor output,
-                  int input_rows, int input_cols, int row_stride,
-                  int col_stride,int pad_h, int pad_w, TensorFormat data_format) const {
+                  typename TTypes<float, 4>::ConstTensor output, int input_rows,
+                  int input_cols, int row_stride, int col_stride, int pad_h,
+                  int pad_w, TensorFormat data_format) const {
     auto batch = input.dimension(0);
     auto in_depth = input.dimension(3);
     auto out_depth = output.dimension(3);
     auto filter_rows = filter.dimension(0);
     auto filter_cols = filter.dimension(1);
 
     auto num_threads =
         context->device()->tensorflow_cpu_worker_threads()->num_threads;
     // See libxsmm_dnn.h for this struct definition.
@@ -144,13 +145,11 @@ struct LaunchXsmmBackwardFilter<CPUDevice, float> {
     desc.fuse_ops = LIBXSMM_DNN_CONV_FUSE_NONE;
     desc.options = LIBXSMM_DNN_CONV_OPTION_NONE;
     desc.datatype = LIBXSMM_DNN_DATATYPE_F32;
-
 
     if (!CanUseXsmmConv2D(desc, data_format)) {
       return false;
     }
-
 
     auto input_ptr = input.data();
     auto filter_ptr = filter.data();
     auto output_ptr = output.data();
@@ -161,8 +160,6 @@ struct LaunchXsmmBackwardFilter<CPUDevice, float> {
 };
 #endif
 
-
-
 template <typename Device, class T>
 class Conv2DFastBackpropFilterOp : public OpKernel {
  public:
@@ -210,8 +207,7 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, filter_shape, &filter_backprop));
 
-#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
-
+#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
     int64 pad_top, pad_bottom;
     int64 pad_left, pad_right;
     OP_REQUIRES_OK(
@@ -226,22 +222,20 @@ class Conv2DFastBackpropFilterOp : public OpKernel {
         dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
         dims.spatial_dims[1].stride, padding_,
         &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
-
-    if ( pad_left == pad_right && pad_top == pad_bottom ) {
-
+    if (pad_left == pad_right && pad_top == pad_bottom) {
       if (LaunchXsmmBackwardFilter<Device, T>()(
-              context, context->eigen_device<Device>(),
-              input.tensor<T, 4>(),filter_backprop->tensor<T, 4>(),
-              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
-              dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride,
-              (int)dims.spatial_dims[1].stride,(int)pad_top, (int)pad_left, data_format_)) {
-        return;
+              context, context->eigen_device<Device>(), input.tensor<T, 4>(),
+              filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
+              dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
+              static_cast<int>(dims.spatial_dims[0].stride),
+              static_cast<int>(dims.spatial_dims[1].stride),
+              static_cast<int>(pad_top), static_cast<int>(pad_left),
+              data_format_)) {
+        return;
       }
     }
 #endif
-
-
-#endif
 
     functor::SpatialConvolutionBackwardKernel<Device, T>()(
         context->eigen_device<Device>(), filter_backprop->tensor<T, 4>(),
@@ -321,19 +315,20 @@ class Conv2DCustomBackpropFilterOp : public OpKernel {
         dims.spatial_dims[1].input_size, dims.spatial_dims[1].filter_size,
         dims.spatial_dims[1].stride, padding_,
         &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
-#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
-    if ( pad_left == pad_right && pad_top == pad_bottom ) {
-
+
+#if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
+    if (pad_left == pad_right && pad_top == pad_bottom) {
       if (LaunchXsmmBackwardFilter<Device, T>()(
-              context, context->eigen_device<Device>(),
-              input.tensor<T, 4>(),filter_backprop->tensor<T, 4>(),
-              out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
-              dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride,
-              (int)dims.spatial_dims[1].stride,(int)pad_top, (int)pad_left, data_format_)) {
-        return;
+              context, context->eigen_device<Device>(), input.tensor<T, 4>(),
+              filter_backprop->tensor<T, 4>(), out_backprop.tensor<T, 4>(),
+              dims.spatial_dims[0].input_size, dims.spatial_dims[1].input_size,
+              static_cast<int>(dims.spatial_dims[0].stride),
+              static_cast<int>(dims.spatial_dims[1].stride),
+              static_cast<int>(pad_top), static_cast<int>(pad_left),
+              data_format_)) {
+        return;
       }
     }
 #endif
-#endif
 
     // The total dimension size of each kernel.
     const int filter_total_size = dims.spatial_dims[0].filter_size *
@@ -131,7 +131,8 @@ struct LaunchXsmmBackwardInputConvolution {
                   typename TTypes<T, 4>::ConstTensor kernel,
                   typename TTypes<T, 4>::ConstTensor output_backward,
                   int input_rows, int input_cols, int row_stride,
-                  int col_stride, int pad_h, int pad_w, TensorFormat data_format) const {
+                  int col_stride, int pad_h, int pad_w,
+                  TensorFormat data_format) const {
     return false;
   }
 };
@@ -143,7 +144,8 @@ struct LaunchXsmmBackwardInputConvolution<CPUDevice, float> {
                   typename TTypes<float, 4>::ConstTensor kernel,
                   typename TTypes<float, 4>::ConstTensor output_backward,
                   int input_rows, int input_cols, int row_stride,
-                  int col_stride, int pad_h, int pad_w, TensorFormat data_format) const {
+                  int col_stride, int pad_h, int pad_w,
+                  TensorFormat data_format) const {
     auto batch = input_backward.dimension(0);
     auto in_depth = input_backward.dimension(3);
     auto out_depth = output_backward.dimension(3);
@@ -251,13 +253,16 @@ class Conv2DFastBackpropInputOp : public OpKernel {
         dims.spatial_dims[1].stride, padding_,
         &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
 
-    if ( pad_left == pad_right && pad_top == pad_bottom ) {
+    if (pad_left == pad_right && pad_top == pad_bottom) {
       if (LaunchXsmmBackwardInputConvolution<Device, T>()(
               context, context->eigen_device<Device>(),
               in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
               out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
-              dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride,
-              (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left, data_format_)) {
+              dims.spatial_dims[1].input_size,
+              static_cast<int>(dims.spatial_dims[0].stride),
+              static_cast<int>(dims.spatial_dims[1].stride),
+              static_cast<int>(pad_top), static_cast<int>(pad_left),
+              data_format_)) {
         return;
       }
     }
@@ -326,8 +331,8 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
     OP_REQUIRES_OK(context,
                    context->allocate_output(0, input_shape, &in_backprop));
 
-  // TODO(andydavis) Consider moving code shared with
-  // Conv2DCustomBackpropFilterOp into a shared helper function.
+    // TODO(andydavis) Consider moving code shared with
+    // Conv2DCustomBackpropFilterOp into a shared helper function.
 #if defined TENSORFLOW_USE_LIBXSMM && defined TENSORFLOW_USE_LIBXSMM_BACKWARD
     int64 pad_top, pad_bottom;
     int64 pad_left, pad_right;
@@ -344,13 +349,16 @@ class Conv2DCustomBackpropInputOp : public OpKernel {
         dims.spatial_dims[1].stride, padding_,
         &dims.spatial_dims[1].output_size, &pad_left, &pad_right));
 
-    if ( pad_left == pad_right && pad_top == pad_bottom ) {
+    if (pad_left == pad_right && pad_top == pad_bottom) {
       if (LaunchXsmmBackwardInputConvolution<Device, T>()(
               context, context->eigen_device<Device>(),
               in_backprop->tensor<T, 4>(), filter.tensor<T, 4>(),
               out_backprop.tensor<T, 4>(), dims.spatial_dims[0].input_size,
-              dims.spatial_dims[1].input_size, (int)dims.spatial_dims[0].stride,
-              (int)dims.spatial_dims[1].stride, (int)pad_top, (int)pad_left, data_format_)) {
+              dims.spatial_dims[1].input_size,
+              static_cast<int>(dims.spatial_dims[0].stride),
+              static_cast<int>(dims.spatial_dims[1].stride),
+              static_cast<int>(pad_top), static_cast<int>(pad_left),
+              data_format_)) {
         return;
       }
     }
@@ -243,11 +243,10 @@ void DnnPooling3dGradOp<T>::Compute(
   }
 }
 
-#define DEFINE_DNN_OPS(T)            \
-  template class DnnPooling3dOp<T>;  \
+#define DEFINE_DNN_OPS(T)           \
+  template class DnnPooling3dOp<T>; \
   template class DnnPooling3dGradOp<T>;
-TF_CALL_float(DEFINE_DNN_OPS)
-TF_CALL_half(DEFINE_DNN_OPS)
+TF_CALL_float(DEFINE_DNN_OPS) TF_CALL_half(DEFINE_DNN_OPS)
 #undef DEFINE_DNN_OPS
 
 #endif  // GOOGLE_CUDA
@@ -295,8 +295,8 @@ static void MaxPoolingBackwardCustomKernel(
       params.tensor_in_rows, params.tensor_in_cols, params.depth,
       params.out_height, params.out_width, params.window_rows,
       params.window_cols, params.row_stride, params.col_stride, params.pad_rows,
-      params.pad_cols, out_backprop.flat<T>().data(),
-      output->flat<T>().data(), context->eigen_device<Eigen::GpuDevice>());
+      params.pad_cols, out_backprop.flat<T>().data(), output->flat<T>().data(),
+      context->eigen_device<Eigen::GpuDevice>());
 }
 
 template <class T>
@@ -474,8 +474,7 @@ class MaxPoolingGradGradOp : public OpKernel {
     // tensor_out_as_matrix with the corresponding values in
     // top_diff_as_matrix.
     auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
-        int64 start, int64 limit) {
-
+                     int64 start, int64 limit) {
       const int32 depth = params.depth;
       const int32 in_rows = params.tensor_in_rows;
       const int32 in_cols = params.tensor_in_cols;
@@ -1010,38 +1009,34 @@ TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_MAX_POOL_KERNELS);
 // default Eigen implementation so we are using the custom kernel as the
 // default. However, you can explicitly invoke the eigen version using
 // kernel_label_map.
-#define REGISTER_GPU_ONLY_POOL_KERNELS(T)                            \
-  REGISTER_KERNEL_BUILDER(                                           \
-      Name("MaxPool")                                                \
-          .Device(DEVICE_GPU)                                        \
-          .TypeConstraint<T>("T")                                    \
-          .Label("eigen_tensor"),                                    \
-      MaxPoolingOp<GPUDevice, T>);                                   \
-  REGISTER_KERNEL_BUILDER(                                           \
-      Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"),     \
-      MaxPoolingNoMaskOp<GPUDevice, T>);                             \
-  REGISTER_KERNEL_BUILDER(                                           \
-      Name("MaxPoolWithArgmax")                                      \
-          .Device(DEVICE_GPU)                                        \
-          .TypeConstraint<int64>("Targmax")                          \
-          .TypeConstraint<T>("T"),                                   \
-      MaxPoolingWithArgmaxOp<GPUDevice, T>);                         \
-  REGISTER_KERNEL_BUILDER(                                           \
-      Name("MaxPoolGradWithArgmax")                                  \
-          .Device(DEVICE_GPU)                                        \
-          .TypeConstraint<T>("T")                                    \
-          .TypeConstraint<int64>("Targmax"),                         \
-      MaxPoolingGradWithArgmaxOp<GPUDevice, T>);                     \
-  REGISTER_KERNEL_BUILDER(                                           \
-      Name("MaxPoolGradGradWithArgmax")                              \
-          .Device(DEVICE_GPU)                                        \
-          .TypeConstraint<T>("T")                                    \
-          .TypeConstraint<int64>("Targmax"),                         \
-      MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
+#define REGISTER_GPU_ONLY_POOL_KERNELS(T)                        \
+  REGISTER_KERNEL_BUILDER(Name("MaxPool")                        \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<T>("T")            \
+                              .Label("eigen_tensor"),            \
+                          MaxPoolingOp<GPUDevice, T>);           \
+  REGISTER_KERNEL_BUILDER(                                       \
+      Name("MaxPool").Device(DEVICE_GPU).TypeConstraint<T>("T"), \
+      MaxPoolingNoMaskOp<GPUDevice, T>);                         \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolWithArgmax")              \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<int64>("Targmax")  \
+                              .TypeConstraint<T>("T"),           \
+                          MaxPoolingWithArgmaxOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradWithArgmax")          \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<T>("T")            \
+                              .TypeConstraint<int64>("Targmax"), \
+                          MaxPoolingGradWithArgmaxOp<GPUDevice, T>); \
+  REGISTER_KERNEL_BUILDER(Name("MaxPoolGradGradWithArgmax")      \
+                              .Device(DEVICE_GPU)                \
+                              .TypeConstraint<T>("T")            \
+                              .TypeConstraint<int64>("Targmax"), \
+                          MaxPoolingGradGradWithArgmaxOp<GPUDevice, T>);
 TF_CALL_GPU_NUMBER_TYPES(REGISTER_GPU_ONLY_POOL_KERNELS);
 #undef REGISTER_GPU_ONLY_POOL_KERNELS
 
-#endif // GOOGLE_CUDA
+#endif  // GOOGLE_CUDA
 
 #undef REGISTER_MAX_POOL_KERNELS
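The comment at the top of this hunk notes that the custom kernel is the default and the Eigen variant stays reachable through its label. A minimal sketch of selecting the labelled kernel via the _kernel attribute on a node (illustrative; the attribute name follows TensorFlow's kernel-label convention):

    NodeDef node;
    node.set_op("MaxPool");
    // Ask the runtime for the kernel registered with Label("eigen_tensor").
    (*node.mutable_attr())["_kernel"].set_s("eigen_tensor");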
@@ -333,11 +333,11 @@ namespace functor {
 
 template <typename T>
 bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
-    const T* bottom_data, const int batch, const int height,
-    const int width, const int channels, const int pooled_height,
-    const int pooled_width, const int kernel_h, const int kernel_w,
-    const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-    T* top_data, int64* mask, const Eigen::GpuDevice& d) {
+    const T* bottom_data, const int batch, const int height, const int width,
+    const int channels, const int pooled_height, const int pooled_width,
+    const int kernel_h, const int kernel_w, const int stride_h,
+    const int stride_w, const int pad_t, const int pad_l, T* top_data,
+    int64* mask, const Eigen::GpuDevice& d) {
   const int kThreadsPerBlock = 1024;
   const int output_size = batch * channels * pooled_height * pooled_width;
 
@@ -351,14 +351,11 @@ bool MaxPoolForwardWithOptionalArgmax<T>::operator()(
 
 template <typename T>
 bool MaxPoolBackwardNoMask<T>::operator()(
-    const T* bottom_data, const int batch,
-    const int height, const int width,
-    const int channels, const int pooled_height,
-    const int pooled_width, const int kernel_h,
-    const int kernel_w, const int stride_h,
-    const int stride_w, const int pad_t, const int pad_l,
-    const T* top_diff, T* bottom_diff,
-    const Eigen::GpuDevice& d) {
+    const T* bottom_data, const int batch, const int height, const int width,
+    const int channels, const int pooled_height, const int pooled_width,
+    const int kernel_h, const int kernel_w, const int stride_h,
+    const int stride_w, const int pad_t, const int pad_l, const T* top_diff,
+    T* bottom_diff, const Eigen::GpuDevice& d) {
   const int kThreadsPerBlock = 1024;
   const int bottom_size = batch * channels * height * width;
   const int top_size = batch * channels * pooled_height * pooled_width;
@@ -377,9 +374,8 @@ bool MaxPoolBackwardNoMask<T>::operator()(
 
 template <typename T>
 bool MaxPoolBackwardWithArgmax<T>::operator()(
-    const int output_size, const int input_size,
-    const T* top_diff, const int64* mask,
-    const int top_offset, const int bottom_offset,
+    const int output_size, const int input_size, const T* top_diff,
+    const int64* mask, const int top_offset, const int bottom_offset,
     T* bottom_diff, const Eigen::GpuDevice& d) {
   const int kThreadsPerBlock = 1024;
   SetZero<<<(input_size + kThreadsPerBlock - 1) / kThreadsPerBlock,
@@ -36,38 +36,36 @@ template <typename T>
 struct MaxPoolForwardWithOptionalArgmax {
   bool operator()(const T* bottom_data, const int batch, const int height,
                   const int width, const int channels, const int pooled_height,
-                  const int pooled_width, const int kernel_h, const int kernel_w,
-                  const int stride_h, const int stride_w, const int pad_t, const int pad_l,
-                  T* top_data, int64* mask, const Eigen::GpuDevice& d);
+                  const int pooled_width, const int kernel_h,
+                  const int kernel_w, const int stride_h, const int stride_w,
+                  const int pad_t, const int pad_l, T* top_data, int64* mask,
+                  const Eigen::GpuDevice& d);
 };
 
-
 template <typename T>
 struct MaxPoolBackwardWithArgmax {
   bool operator()(const int output_size, const int input_size,
-                  const T* top_diff, const int64* mask,
-                  const int top_offset, const int bottom_offset,
-                  T* bottom_diff, const Eigen::GpuDevice& d);
+                  const T* top_diff, const int64* mask, const int top_offset,
+                  const int bottom_offset, T* bottom_diff,
+                  const Eigen::GpuDevice& d);
 };
 
 template <typename T>
 struct MaxPoolBackwardNoMask {
-  bool operator()(const T* bottom_data, const int batch,
-                  const int height, const int width,
-                  const int channels, const int pooled_height,
+  bool operator()(const T* bottom_data, const int batch, const int height,
+                  const int width, const int channels, const int pooled_height,
                   const int pooled_width, const int kernel_h,
-                  const int kernel_w, const int stride_h,
-                  const int stride_w, const int pad_t, const int pad_l,
-                  const T* top_diff, T* bottom_diff,
-                  const Eigen::GpuDevice& d);
+                  const int kernel_w, const int stride_h, const int stride_w,
+                  const int pad_t, const int pad_l, const T* top_diff,
+                  T* bottom_diff, const Eigen::GpuDevice& d);
 };
 
 template <typename T>
 struct MaxPoolGradBackwardWithArgmax {
   bool operator()(const int output_size, const int input_size,
-                  const T* top_diff, const int64* mask,
-                  const int top_offset, const int bottom_offset,
-                  T* bottom_diff, const Eigen::GpuDevice& d);
+                  const T* top_diff, const int64* mask, const int top_offset,
+                  const int bottom_offset, T* bottom_diff,
+                  const Eigen::GpuDevice& d);
 };
 
 template <typename T>
@@ -75,12 +73,10 @@ struct MaxPoolGradBackwardNoMask {
   bool operator()(TensorFormat data_format, const T* bottom_data,
                   const T* output_data, const int batch,
                   const int pooled_height, const int pooled_width,
-                  const int channels, const int height,
-                  const int width, const int kernel_h,
-                  const int kernel_w, const int stride_h,
+                  const int channels, const int height, const int width,
+                  const int kernel_h, const int kernel_w, const int stride_h,
                   const int stride_w, const int pad_t, const int pad_l,
-                  const T* top_diff, T* bottom_diff,
-                  const Eigen::GpuDevice& d);
+                  const T* top_diff, T* bottom_diff, const Eigen::GpuDevice& d);
 };
 
 } // namespace functor
@@ -336,10 +336,11 @@ class MklAvgPoolingGradOp : public OpKernel {
     if (!outbackprop_in_mkl_format) {
       // For avgpooling, tensor_in_shape should have 1 dimension, and 4
      // elements.
-      OP_REQUIRES(context, tensor_in_shape.dims() == 1 &&
-                               tensor_in_shape.NumElements() == 4,
-                  errors::InvalidArgument("original input shape must be "
-                                          "1-dimensional and 4 elements"));
+      OP_REQUIRES(
+          context,
+          tensor_in_shape.dims() == 1 && tensor_in_shape.NumElements() == 4,
+          errors::InvalidArgument("original input shape must be "
+                                  "1-dimensional and 4 elements"));
 
       // For avgpooling, out_backprop should have 4 dimensions.
       OP_REQUIRES(context, out_backprop.dims() == 4,
@@ -38,9 +38,9 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
-#include "tensorflow/core/util/mkl_util.h"
 #include "third_party/mkl/include/mkl_dnn.h"
 #include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
 
@@ -37,9 +37,9 @@ limitations under the License.
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
 
-#include "tensorflow/core/util/mkl_util.h"
 #include "third_party/mkl/include/mkl_dnn.h"
 #include "third_party/mkl/include/mkl_dnn_types.h"
+#include "tensorflow/core/util/mkl_util.h"
 
 namespace tensorflow {
 
@@ -23,6 +23,8 @@ limitations under the License.
 #define EIGEN_USE_THREADS
 #include <algorithm>
 #include <vector>
+#include "third_party/mkl/include/mkl_dnn.h"
+#include "third_party/mkl/include/mkl_dnn_types.h"
 #include "tensorflow/core/framework/numeric_op.h"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"
@@ -40,8 +42,6 @@ limitations under the License.
 #include "tensorflow/core/util/tensor_format.h"
 #include "tensorflow/core/util/use_cudnn.h"
 #include "tensorflow/core/util/work_sharder.h"
-#include "third_party/mkl/include/mkl_dnn.h"
-#include "third_party/mkl/include/mkl_dnn_types.h"
 
 namespace tensorflow {
 
@@ -107,10 +107,10 @@ class MklConv2DOp : public OpKernel {
     const int64 input_depth =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'C')
                             : GetTensorDim(input, data_format_, 'C');
-    OP_REQUIRES(
-        context, input_depth == filter.dim_size(2),
-        errors::InvalidArgument("input and filter must have the same depth: ",
-                                input_depth, " vs ", filter.dim_size(2)));
+    OP_REQUIRES(context, input_depth == filter.dim_size(2),
+                errors::InvalidArgument(
+                    "input and filter must have the same depth: ", input_depth,
+                    " vs ", filter.dim_size(2)));
     // The last dimension for filter is out_depth.
     const int out_depth = static_cast<int>(filter.dim_size(3));
 
@@ -119,9 +119,10 @@ class MklConv2DOp : public OpKernel {
     const int64 input_rows_raw =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'H')
                             : GetTensorDim(input, data_format_, 'H');
-    OP_REQUIRES(context, FastBoundsCheck(input_rows_raw,
-                                         std::numeric_limits<int>::max()),
-                errors::InvalidArgument("Input rows too large"));
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(input_rows_raw, std::numeric_limits<int>::max()),
+        errors::InvalidArgument("Input rows too large"));
     const int input_rows = static_cast<int>(input_rows_raw);
     const int filter_rows = static_cast<int>(filter.dim_size(0));
 
@@ -130,9 +131,10 @@ class MklConv2DOp : public OpKernel {
     const int64 input_cols_raw =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'W')
                             : GetTensorDim(input, data_format_, 'W');
-    OP_REQUIRES(context, FastBoundsCheck(input_cols_raw,
-                                         std::numeric_limits<int>::max()),
-                errors::InvalidArgument("Input cols too large"));
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(input_cols_raw, std::numeric_limits<int>::max()),
+        errors::InvalidArgument("Input cols too large"));
     const int input_cols = static_cast<int>(input_cols_raw);
     const int filter_cols = static_cast<int>(filter.dim_size(1));
 
@@ -140,9 +142,10 @@ class MklConv2DOp : public OpKernel {
     const int64 input_batch_raw =
         input_in_mkl_format ? GetMklTensorDim(mkl_context.input_shape, 'N')
                             : GetTensorDim(input, data_format_, 'N');
-    OP_REQUIRES(context, FastBoundsCheck(input_batch_raw,
-                                         std::numeric_limits<int>::max()),
-                errors::InvalidArgument("batch is too large"));
+    OP_REQUIRES(
+        context,
+        FastBoundsCheck(input_batch_raw, std::numeric_limits<int>::max()),
+        errors::InvalidArgument("batch is too large"));
     const int batch = static_cast<int>(input_batch_raw);
 
     // For now we take the stride from the second and third dimensions only (we
@@ -393,18 +393,19 @@ class MklMaxPoolingGradOp : public OpKernel {
     if (workspace_enabled == false) {
       if (convert_input != nullptr) {
         if (input_in_mkl_format == false) {
-          CHECK_EQ(
-              dnnConversionExecute_F32(
-                  convert_input, const_cast<void*>(static_cast<const void*>(
-                                     tensor_in.flat<T>().data())),
-                  input_buf),
-              E_SUCCESS);
+          CHECK_EQ(dnnConversionExecute_F32(
+                       convert_input,
+                       const_cast<void*>(static_cast<const void*>(
+                           tensor_in.flat<T>().data())),
+                       input_buf),
+                   E_SUCCESS);
           CHECK_EQ(dnnDelete_F32(convert_input), E_SUCCESS);
           convert_input = nullptr;
         } else {
           input_shape.GetConvertedFlatData(
-              lt_input_prim, const_cast<void*>(static_cast<const void*>(
-                                 tensor_in.flat<T>().data())),
+              lt_input_prim,
+              const_cast<void*>(
+                  static_cast<const void*>(tensor_in.flat<T>().data())),
               input_buf);
         }
         pooling_resfwd[dnnResourceSrc] = input_buf;
@@ -449,8 +450,9 @@ class MklMaxPoolingGradOp : public OpKernel {
       CHECK_EQ(dnnDelete_F32(convert_outbackprop), E_SUCCESS);
     } else {
       output_backprop_shape.GetConvertedFlatData(
-          lt_outbackprop_prim, const_cast<void*>(static_cast<const void*>(
-                                   out_backprop.flat<T>().data())),
+          lt_outbackprop_prim,
+          const_cast<void*>(
+              static_cast<const void*>(out_backprop.flat<T>().data())),
           outbackprop_buf);
     }
     pooling_res[dnnResourceDiffDst] = outbackprop_buf;
@@ -14,153 +14,137 @@ limitations under the License.
 ==============================================================================*/
 
 #ifdef INTEL_MKL
-#include <vector>
-#include "tensorflow/core/framework/common_shape_fns.h"
 #include "tensorflow/core/kernels/mkl_pooling_ops_common.h"
+#include <vector>
+#include "tensorflow/core/common_runtime/device.h"
+#include "tensorflow/core/framework/common_shape_fns.h"
 
 namespace tensorflow {
 
 // Initialization for TensorFlow format
 void MklPoolParameters::Init(OpKernelContext* context,
                              const std::vector<int32>& ksize,
-                             const std::vector<int32>& stride,
-                             Padding padding,
+                             const std::vector<int32>& stride, Padding padding,
                              TensorFormat data_format,
                              const TensorShape& tensor_in_shape) {
   // For maxpooling, tensor_in should have 4 dimensions.
   OP_REQUIRES(context, tensor_in_shape.dims() == 4,
               errors::InvalidArgument("tensor_in must be 4-dimensional"));
 
   depth = GetTensorDim(tensor_in_shape, data_format, 'C');
   tensor_in_cols = GetTensorDim(tensor_in_shape, data_format, 'W');
   tensor_in_rows = GetTensorDim(tensor_in_shape, data_format, 'H');
   tensor_in_batch = GetTensorDim(tensor_in_shape, data_format, 'N');
 
   Init(context, ksize, stride, padding, data_format);
 }
 
 // Initialization for MKL format
 void MklPoolParameters::Init(OpKernelContext* context,
                              const std::vector<int32>& ksize,
-                             const std::vector<int32>& stride,
-                             Padding padding,
+                             const std::vector<int32>& stride, Padding padding,
                              TensorFormat data_format,
                              const MklShape* mklInputShape) {
   // Get the input sizes
   depth = mklInputShape->GetSizes()[2];
   tensor_in_cols = mklInputShape->GetSizes()[0];
   tensor_in_rows = mklInputShape->GetSizes()[1];
   tensor_in_batch = mklInputShape->GetSizes()[3];
 
   Init(context, ksize, stride, padding, data_format);
 }
 
 // Common Initialization for TensorFlow and MKL formats
 void MklPoolParameters::Init(OpKernelContext* context,
                              const std::vector<int32>& ksize,
-                             const std::vector<int32>& stride,
-                             Padding padding,
+                             const std::vector<int32>& stride, Padding padding,
                              TensorFormat data_format) {
   // Get the data format
   this->data_format = data_format;
 
   // Get the output sizes
   window_rows = GetTensorDim(ksize, data_format, 'H');
   window_cols = GetTensorDim(ksize, data_format, 'W');
   depth_window = GetTensorDim(ksize, data_format, 'C');
 
   // Get the strides
   row_stride = GetTensorDim(stride, data_format, 'H');
   col_stride = GetTensorDim(stride, data_format, 'W');
   depth_stride = GetTensorDim(stride, data_format, 'C');
 
   // We only support 2D pooling across width/height and depthwise
   // pooling, not a combination.
   OP_REQUIRES(context,
               (depth_window == 1 || (window_rows == 1 && window_cols == 1)),
               errors::Unimplemented(
                   "MaxPooling supports exactly one of pooling across depth "
                   "or pooling across width/height."));
 
   if (depth_window == 1) {
-    OP_REQUIRES_OK(context,
-                   GetWindowedOutputSizeVerbose(tensor_in_rows,
-                                                window_rows,
-                                                row_stride,
-                                                padding,
-                                                &out_height,
-                                                &pad_top,
-                                                &pad_bottom));
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                tensor_in_rows, window_rows, row_stride,
+                                padding, &out_height, &pad_top, &pad_bottom));
 
-    OP_REQUIRES_OK(context,
-                   GetWindowedOutputSizeVerbose(tensor_in_cols,
-                                                window_cols,
-                                                col_stride,
-                                                padding,
-                                                &out_width,
-                                                &pad_left,
-                                                &pad_right));
+    OP_REQUIRES_OK(context, GetWindowedOutputSizeVerbose(
+                                tensor_in_cols, window_cols, col_stride,
+                                padding, &out_width, &pad_left, &pad_right));
   } else {
     // Our current version of depthwise max pooling does not support
     // any padding, and expects the depth_window to equal the depth
     // stride (no overlapping).
     OP_REQUIRES(context, depth % depth_window == 0,
                 errors::Unimplemented("Depthwise max pooling requires the"
                                       " depth window to evenly divide the"
                                       " input depth"));
     OP_REQUIRES(context, depth_stride == depth_window,
                 errors::Unimplemented("Depthwise max pooling requires the"
                                       " depth window to equal the depth"
                                       " stride"));
 
     // The current version of depthwise max is only implemented on CPU.
     OP_REQUIRES(context,
                 (DeviceType(static_cast<Device*>(context->device())
                                 ->attributes()
                                 .device_type()) == DeviceType(DEVICE_CPU)),
                 errors::Unimplemented("Depthwise max pooling is currently "
                                       "only implemented for CPU devices."));
 
     pad_depth = 0;
     out_depth = depth / depth_window;
   }
 }
 
 // Transfers the right parameters for pooling to the op parameters
 // Updates context->status if there is an invalid input.
-void ExtractMklOpParams(OpKernelContext* context,
-                        TensorFormat data_format,
-                        const MklPoolParameters &params,
-                        MklPoolingOpParams *mkl_params) {
+void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
+                        const MklPoolParameters& params,
+                        MklPoolingOpParams* mkl_params) {
   mkl_params->in_sizes[0] = params.tensor_in_cols;
   mkl_params->in_sizes[1] = params.tensor_in_rows;
   mkl_params->in_sizes[2] = params.depth;
   mkl_params->in_sizes[3] = params.tensor_in_batch;
 
-  GetStridesFromSizes(data_format,
-                      mkl_params->in_strides,
-                      mkl_params->in_sizes);
+  GetStridesFromSizes(data_format, mkl_params->in_strides,
                      mkl_params->in_sizes);
 
   mkl_params->out_sizes[0] = params.out_width;
   mkl_params->out_sizes[1] = params.out_height;
   mkl_params->out_sizes[2] = params.depth;
   mkl_params->out_sizes[3] = params.tensor_in_batch;
 
-  GetStridesFromSizes(data_format,
-                      mkl_params->out_strides,
-                      mkl_params->out_sizes);
+  GetStridesFromSizes(data_format, mkl_params->out_strides,
                      mkl_params->out_sizes);
 
   mkl_params->in_offset[0] = -params.pad_left;
   mkl_params->in_offset[1] = -params.pad_top;
   mkl_params->in_offset[2] = -params.pad_right;
   mkl_params->in_offset[3] = -params.pad_bottom;
 
   mkl_params->kernel_stride[0] = params.col_stride;
   mkl_params->kernel_stride[1] = params.row_stride;
 
   mkl_params->kernel_size[0] = params.window_cols;
   mkl_params->kernel_size[1] = params.window_rows;
 }
 }  // namespace tensorflow
 #endif  // INTEL_MKL
@@ -76,17 +76,16 @@ typedef struct {
   size_t in_strides[4];
   size_t out_sizes[4];
   size_t out_strides[4];
-   int in_offset[4];
+  int in_offset[4];
   size_t kernel_stride[2];
   size_t kernel_size[2];
 } MklPoolingOpParams;
 
 // Transfers the right parameters for pooling to the op parameters
 // Updates context->status if there is an invalid input.
-void ExtractMklOpParams(OpKernelContext* context,
-                        TensorFormat data_format,
-                        const MklPoolParameters &params,
-                        MklPoolingOpParams *mkl_params);
+void ExtractMklOpParams(OpKernelContext* context, TensorFormat data_format,
+                        const MklPoolParameters& params,
+                        MklPoolingOpParams* mkl_params);
 } // namespace tensorflow
 
 #endif // INTEL_MKL
@ -1,397 +1,397 @@
|
||||
/* Copyright 2015 The TensorFlow Authors. All Rights Reserved.
|
||||
|
||||
Licensed under the Apache License, Version 2.0 (the "License");
|
||||
you may not use this file except in compliance with the License.
|
||||
You may obtain a copy of the License at
|
||||
|
||||
http://www.apache.org/licenses/LICENSE-2.0
|
||||
|
||||
Unless required by applicable law or agreed to in writing, software
|
||||
distributed under the License is distributed on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
See the License for the specific language governing permissions and
|
||||
limitations under the License.
|
||||
==============================================================================*/
|
||||
|
||||
// See docs in ../ops/nn_ops.cc.
|
||||
#ifdef INTEL_MKL
|
||||
|
||||
#include "tensorflow/core/framework/numeric_op.h"
|
||||
#include "tensorflow/core/framework/op_kernel.h"
|
||||
#include "tensorflow/core/framework/register_types.h"
|
||||
#include "tensorflow/core/framework/tensor.h"
|
||||
#include "tensorflow/core/lib/core/errors.h"
|
||||
#include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
|
||||
|
||||
#include "tensorflow/core/platform/default/logging.h"
|
||||
#include "tensorflow/core/util/mkl_util.h"
|
||||
#include "third_party/mkl/include/mkl_dnn.h"
|
||||
#include "third_party/mkl/include/mkl_dnn_types.h"
|
||||
|
||||
namespace tensorflow {
|
||||
|
||||
typedef Eigen::ThreadPoolDevice CPUDevice;
|
||||
|
||||
struct MklReluHelpers {
|
||||
static void ValidateSameSizeHelper(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a) {
|
||||
OP_REQUIRES(context, a.IsSameSize(g),
|
||||
errors::InvalidArgument("g and a must be the same size"));
|
||||
}
|
||||
static bool ValidateSameSize(OpKernelContext* context, const Tensor& g,
|
||||
const Tensor& a) {
|
||||
ValidateSameSizeHelper(context, g, a);
|
||||
return context->status().ok();
|
||||
}
|
||||
};
|
||||
|
||||
template <typename Device, typename T>
|
||||
class MklReluOp : public OpKernel {
|
||||
public:
|
||||
~MklReluOp() {}
|
||||
|
||||
explicit MklReluOp(OpKernelConstruction* context) : OpKernel(context) {}
|
||||
|
||||
void Compute(OpKernelContext* context) override {
|
||||
MklReluOpContext mkl_context;
|
||||
|
||||
const Tensor& input = MklGetInput(context, 0);
|
||||
GetMklShape(context, 0, &mkl_context.input_shape);
|
||||
void* user_i = static_cast<void*>(const_cast<T*>(input.flat<T>().data()));
|
||||
bool input_in_mkl_format = mkl_context.input_shape.IsMklTensor();
|
||||
if (!input_in_mkl_format && !input.dims()) { // handle the case of a scalar
|
||||
const TensorShape& o_shape = input.shape();
|
||||
Tensor* out_tensor = nullptr;
|
||||
mkl_context.output_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &out_tensor, o_shape,
|
||||
mkl_context.output_shape);
|
||||
void* out_o = static_cast<void*>(out_tensor->flat<T>().data());
|
||||
(static_cast<T*>(out_o))[0] =
|
||||
std::max((static_cast<T*>(user_i))[0], static_cast<T>(0));
|
||||
return;
|
||||
}
|
||||
|
||||
// Generate size, stride for input if input is in MKL format.
|
||||
if (input_in_mkl_format) {
|
||||
mkl_context.in_dims = mkl_context.input_shape.GetDimension();
|
||||
mkl_context.in_sizes = new size_t[mkl_context.in_dims];
|
||||
mkl_context.in_strides = new size_t[mkl_context.in_dims];
|
||||
for (int i = 0; i < mkl_context.in_dims; i++) {
|
||||
mkl_context.in_sizes[i] = mkl_context.input_shape.GetSizes()[i];
|
||||
mkl_context.in_strides[i] = mkl_context.input_shape.GetStrides()[i];
|
||||
}
|
||||
} else {
|
||||
mkl_context.in_dims = input.dims();
|
||||
mkl_context.in_sizes = new size_t[mkl_context.in_dims];
|
||||
mkl_context.in_strides = new size_t[mkl_context.in_dims];
|
||||
for (int i = 0; i < mkl_context.in_dims; i++) {
|
||||
mkl_context.in_sizes[i] = input.dim_size((mkl_context.in_dims - 1) - i);
|
||||
}
|
||||
mkl_context.in_strides[0] = 1;
|
||||
for (int i = 1; i < mkl_context.in_dims; i++) {
|
||||
mkl_context.in_strides[i] =
|
||||
mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
|
||||
}
|
||||
}
|
||||
|
||||
float negative_slope = 0.0;
|
||||
mkl_context.MklCreateInputLayouts(context);
|
||||
CHECK_EQ(dnnReLUCreateForward_F32(&mkl_context.prim_relu_fwd, NULL,
|
||||
mkl_context.lt_input, negative_slope),
|
||||
E_SUCCESS);
|
||||
|
||||
Tensor* output = nullptr;
|
||||
|
||||
if (input_in_mkl_format) {
|
||||
TensorShape tf_shape;
|
||||
mkl_context.output_shape.SetMklTensor(true);
|
||||
mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_fwd,
|
||||
dnnResourceDst);
|
||||
mkl_context.output_shape.SetTfLayout(
|
||||
mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
|
||||
mkl_context.output_shape.SetTfDimOrder(
|
||||
mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
|
||||
tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
|
||||
mkl_context.output_shape.GetMklLayout())) /
|
||||
sizeof(T));
|
||||
AllocateOutputSetMklshape(context, 0, &output, tf_shape,
|
||||
mkl_context.output_shape);
|
||||
} else {
|
||||
const TensorShape& o_shape = input.shape();
|
||||
mkl_context.output_shape.SetMklTensor(false);
|
||||
AllocateOutputSetMklshape(context, 0, &output, o_shape,
|
||||
mkl_context.output_shape);
|
||||
}
|
||||
|
||||
void* user_o = static_cast<void*>(const_cast<T*>(output->flat<T>().data()));
|
||||
|
||||
mkl_context.relu_res[dnnResourceDst] = user_o;
|
||||
mkl_context.relu_res[dnnResourceSrc] = user_i;
|
||||
CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_fwd, mkl_context.relu_res),
|
||||
E_SUCCESS);
|
||||
mkl_context.MklCleanup();
|
||||
}
|
||||
|
||||
private:
|
||||
typedef struct {
|
||||
int in_dims;
|
||||
size_t* in_sizes;
|
||||
size_t* in_strides;
|
||||
MklShape input_shape, output_shape;
|
||||
dnnPrimitive_t prim_relu_fwd = nullptr;
|
||||
void* relu_res[dnnResourceNumber];
|
||||
dnnLayout_t lt_input = nullptr;
|
||||
|
||||
void MklCleanup() {
|
||||
bool input_in_mkl_format = input_shape.IsMklTensor();
|
||||
if (!input_in_mkl_format) {
|
||||
dnnLayoutDelete_F32(lt_input);
|
||||
free(in_sizes);
|
||||
free(in_strides);
|
||||
}
|
||||
dnnDelete_F32(prim_relu_fwd);
|
||||
}
|
||||
|
||||
void MklCreateInputLayouts(OpKernelContext* context) {
|
||||
bool input_in_mkl_format = input_shape.IsMklTensor();
|
||||
if (!input_in_mkl_format) {
|
||||
CHECK_EQ(dnnLayoutCreate_F32(<_input, in_dims, in_sizes, in_strides),
|
||||
E_SUCCESS);
|
||||
} else {
|
||||
lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
|
||||
}
|
||||
}
|
||||
} MklReluOpContext;
|
||||
};

template <typename Device, typename T>
class MklReluGradOp : public OpKernel {
 public:
  ~MklReluGradOp() {}

  explicit MklReluGradOp(OpKernelConstruction* context) : OpKernel(context) {}

  void Compute(OpKernelContext* context) override;

 private:
  typedef struct {
    int in_dims;
    size_t* in_sizes;
    size_t* in_strides;
    MklShape input_shape, grad_shape, output_shape;
    void* relu_res[dnnResourceNumber];
    dnnPrimitive_t prim_relu_bwd;
    dnnLayout_t lt_input, lt_grad;

    void MklPrepareReluGradInputs(OpKernelContext* context,
                                  Tensor* mkl_tmp_grad_buf_tensor,
                                  Tensor* mkl_tmp_input_buf_tensor) {
      dnnPrimitive_t cv_user_to_reluB_input, cv_user_to_reluB_grad;
      dnnLayout_t mkl_lt_internal_input, mkl_lt_internal_grad;

      const Tensor& g = MklGetInput(context, 0);
      const Tensor& a = MklGetInput(context, 1);

      void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
      void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));

      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(
                   &mkl_lt_internal_grad, prim_relu_bwd, dnnResourceDiffDst),
               E_SUCCESS);

      CHECK_EQ(dnnLayoutCreateFromPrimitive_F32(&mkl_lt_internal_input,
                                                prim_relu_bwd, dnnResourceSrc),
               E_SUCCESS);

      if (!dnnLayoutCompare_F32(mkl_lt_internal_grad, lt_grad)) {
        AllocTmpBuffer(context, mkl_tmp_grad_buf_tensor, mkl_lt_internal_grad,
                       &relu_res[dnnResourceDiffDst]);
        CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_grad, lt_grad,
                                         mkl_lt_internal_grad),
                 E_SUCCESS);
        CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_grad, user_g,
                                          relu_res[dnnResourceDiffDst]),
                 E_SUCCESS);
        dnnDelete_F32(cv_user_to_reluB_grad);
      } else {
        relu_res[dnnResourceDiffDst] = user_g;
      }

      if (!dnnLayoutCompare_F32(mkl_lt_internal_input, lt_input)) {
        AllocTmpBuffer(context, mkl_tmp_input_buf_tensor, mkl_lt_internal_input,
                       &relu_res[dnnResourceSrc]);
        CHECK_EQ(dnnConversionCreate_F32(&cv_user_to_reluB_input, lt_input,
                                         mkl_lt_internal_input),
                 E_SUCCESS);
        CHECK_EQ(dnnConversionExecute_F32(cv_user_to_reluB_input, user_i,
                                          relu_res[dnnResourceSrc]),
                 E_SUCCESS);
        dnnDelete_F32(cv_user_to_reluB_input);
      } else {
        relu_res[dnnResourceSrc] = user_i;
      }

      dnnLayoutDelete_F32(mkl_lt_internal_input);
      dnnLayoutDelete_F32(mkl_lt_internal_grad);
    }

    void MklCreateInputLayouts(OpKernelContext* context) {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      if (!input_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_input, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_input = static_cast<dnnLayout_t>(input_shape.GetCurLayout());
      }

      if (!grad_is_mkl) {
        CHECK_EQ(dnnLayoutCreate_F32(&lt_grad, in_dims, in_sizes, in_strides),
                 E_SUCCESS);
      } else {
        lt_grad = static_cast<dnnLayout_t>(grad_shape.GetCurLayout());
      }
    }

    void MklCleanup() {
      bool grad_is_mkl = grad_shape.IsMklTensor();
      bool input_is_mkl = input_shape.IsMklTensor();
      dnnDelete_F32(prim_relu_bwd);
      if (!input_is_mkl) {
        dnnLayoutDelete_F32(lt_input);
        free(in_sizes);
        free(in_strides);
      }
      if (!grad_is_mkl) {
        dnnLayoutDelete_F32(lt_grad);
      }
    }
  } MklReluGradOpContext;
};
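
// For reference: the backward primitive created in Compute() below implements
// dx = dy * (x > 0), matching the scalar fallback path in that method. A
// minimal sketch over contiguous buffers; illustrative only (hypothetical
// name, not used by the kernel):
template <typename T>
void ReferenceReluGrad(const T* g, const T* a, T* out, int64 n) {
  for (int64 i = 0; i < n; ++i) {
    // Pass the gradient through only where the forward input was positive.
    out[i] = g[i] * static_cast<T>(a[i] > static_cast<T>(0));
  }
}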

template <typename Device, typename T>
void MklReluGradOp<Device, T>::Compute(OpKernelContext* context) {
  MklReluGradOpContext mkl_context;
  const Tensor& g = MklGetInput(context, 0);
  const Tensor& a = MklGetInput(context, 1);

  void* user_i = static_cast<void*>(const_cast<T*>(a.flat<T>().data()));
  void* user_g = static_cast<void*>(const_cast<T*>(g.flat<T>().data()));

  GetMklShape(context, 0, &mkl_context.grad_shape);
  GetMklShape(context, 1, &mkl_context.input_shape);

  bool grad_is_mkl = mkl_context.grad_shape.IsMklTensor();
  bool input_is_mkl = mkl_context.input_shape.IsMklTensor();
  if (!input_is_mkl && !grad_is_mkl &&
      !MklReluHelpers::ValidateSameSize(context, g, a))
    return;
  Tensor* output = nullptr;
  if (!input_is_mkl && !grad_is_mkl &&
      !a.dims()) {  // handle the case of a scalar
    // Allocate space for g and
    const TensorShape& g_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklshape(context, 0, &output, g_shape,
                              mkl_context.output_shape);
    void* out_o = static_cast<void*>(output->flat<T>().data());
    (static_cast<T*>(out_o))[0] =
        (static_cast<T*>(user_g))[0] * ((static_cast<T*>(user_i))[0] > 0);
    return;
  }

  // Generate size, stride for input if input/grad is in MKL format.
  if (grad_is_mkl || input_is_mkl) {
    const MklShape* tmp_mkl_shape =
        (grad_is_mkl) ? &mkl_context.grad_shape : &mkl_context.input_shape;

    mkl_context.in_dims = tmp_mkl_shape->GetDimension();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];
    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = tmp_mkl_shape->GetSizes()[i];
      mkl_context.in_strides[i] = tmp_mkl_shape->GetStrides()[i];
    }
  } else {
    mkl_context.in_dims = g.dims();
    mkl_context.in_strides = new size_t[mkl_context.in_dims];
    mkl_context.in_sizes = new size_t[mkl_context.in_dims];

    for (int i = 0; i < mkl_context.in_dims; i++) {
      mkl_context.in_sizes[i] = g.dim_size((mkl_context.in_dims - 1) - i);
    }
    mkl_context.in_strides[0] = 1;
    for (int i = 1; i < mkl_context.in_dims; i++) {
      mkl_context.in_strides[i] =
          mkl_context.in_strides[i - 1] * mkl_context.in_sizes[i - 1];
    }
  }

  mkl_context.MklCreateInputLayouts(context);
  float negative_slope = 0.0;
  CHECK_EQ(dnnReLUCreateBackward_F32(&mkl_context.prim_relu_bwd, NULL,
                                     mkl_context.lt_grad, mkl_context.lt_input,
                                     negative_slope),
           E_SUCCESS);
  Tensor mkl_tmp_grad_buf_tensor, mkl_tmp_input_buf_tensor;
  mkl_context.MklPrepareReluGradInputs(context, &mkl_tmp_grad_buf_tensor,
                                       &mkl_tmp_input_buf_tensor);

  if (input_is_mkl ||
      grad_is_mkl) { /*if grad or input are MKL leave it in MKL*/
    TensorShape tf_shape;
    mkl_context.output_shape.SetMklTensor(true);
    mkl_context.output_shape.SetMklLayout(mkl_context.prim_relu_bwd,
                                          dnnResourceDiffSrc);
    mkl_context.output_shape.SetTfLayout(
        mkl_context.in_dims, mkl_context.in_sizes, mkl_context.in_strides);
    // If input_is_mkl or grad_is_mkl, then we copy strides and sizes from Mkl
    // shape of one that is in MKL layout.
    if (grad_is_mkl == true) {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.grad_shape.GetTfToMklDimMap());
    } else {
      mkl_context.output_shape.SetTfDimOrder(
          mkl_context.in_dims, mkl_context.input_shape.GetTfToMklDimMap());
    }

    tf_shape.AddDim(dnnLayoutGetMemorySize_F32(static_cast<dnnLayout_t>(
                        mkl_context.output_shape.GetMklLayout())) /
                    sizeof(T));
    AllocateOutputSetMklshape(context, 0, &output, tf_shape,
                              mkl_context.output_shape);

  } else {
    const TensorShape& o_shape = g.shape();
    mkl_context.output_shape.SetMklTensor(false);
    AllocateOutputSetMklshape(context, 0, &output, o_shape,
                              mkl_context.output_shape);
  }

  mkl_context.relu_res[dnnResourceDiffSrc] =
      static_cast<void*>(output->flat<T>().data());

  CHECK_EQ(dnnExecute_F32(mkl_context.prim_relu_bwd, mkl_context.relu_res),
           E_SUCCESS);
  mkl_context.MklCleanup();
}

/* Register DNN kernels for supported operations and supported types - right
 * now it is only Relu and f32 */
#define REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES(type)                   \
  REGISTER_KERNEL_BUILDER(Name("MklRelu")                                 \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<type>("T")                  \
                              .Label(mkl_layer_registry::kMklLayerLabel), \
                          MklReluOp<CPUDevice, type>);                    \
  REGISTER_KERNEL_BUILDER(Name("MklReluGrad")                             \
                              .Device(DEVICE_CPU)                         \
                              .TypeConstraint<type>("T")                  \
                              .Label(mkl_layer_registry::kMklLayerLabel), \
                          MklReluGradOp<CPUDevice, type>);
TF_CALL_float(REGISTER_RELU_MKL_SUPPORTED_KERNELS_TYPES);

}  // namespace tensorflow

#endif  // INTEL_MKL
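
// Note on the registration above: TF_CALL_float instantiates the macro for
// float only, since every MKL DNN call in this file is a *_F32 variant. The
// expansion is equivalent to writing out:
//
//   REGISTER_KERNEL_BUILDER(Name("MklRelu")
//                               .Device(DEVICE_CPU)
//                               .TypeConstraint<float>("T")
//                               .Label(mkl_layer_registry::kMklLayerLabel),
//                           MklReluOp<CPUDevice, float>);
//
// plus the matching "MklReluGrad" registration.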
@@ -580,8 +580,7 @@ struct LaunchMaxPooling3dGradGradOp<CPUDevice, T> {
        *(context->device()->tensorflow_cpu_worker_threads());

    auto shard = [&params, &in_mat, &out_mat, &top_diff_mat, &bottom_diff_mat](
        int64 start, int64 limit) {

                     int64 start, int64 limit) {
      const int32 depth = params.depth;
      const int32 in_planes = params.tensor_in_planes;
      const int32 in_rows = params.tensor_in_rows;
@@ -682,10 +681,9 @@ class MaxPooling3dGradGradOp : public OpKernel {
                "Pooling is not yet supported on the batch dimension."));
    const int32 ksize_c = GetTensorDim(ksize_, data_format_, 'C');
    const int32 stride_c = GetTensorDim(stride_, data_format_, 'C');
    OP_REQUIRES(
        context, ksize_c == 1 && stride_c == 1,
        errors::Unimplemented(
            "MaxPooling3dGradGrad is not yet supported on the depth dimension."));
    OP_REQUIRES(context, ksize_c == 1 && stride_c == 1,
                errors::Unimplemented("MaxPooling3dGradGrad is not yet "
                                      "supported on the depth dimension."));
  }

  void Compute(OpKernelContext* context) override {
@@ -703,7 +701,7 @@ class MaxPooling3dGradGradOp : public OpKernel {
        context, out_grad_backprop.dims() == 5,
        errors::InvalidArgument("out_grad_backprop must be 5-dimensional"));

    Pool3dParameters params{context, ksize_, stride_,
    Pool3dParameters params{context, ksize_, stride_,
                            padding_, data_format_, tensor_in.shape()};

    Tensor* output = nullptr;
@@ -736,12 +734,11 @@ class MaxPooling3dGradGradOp : public OpKernel {
  REGISTER_KERNEL_BUILDER(                                                \
      Name("AvgPool3D").Device(DEVICE_##D).TypeConstraint<T>("T"),        \
      Pooling3DOp<D##Device, T, AVG>);                                    \
  REGISTER_KERNEL_BUILDER(                                                \
      Name("AvgPool3DGrad")                                               \
          .Device(DEVICE_##D)                                             \
          .TypeConstraint<T>("T")                                         \
          .HostMemory("orig_input_shape"),                                \
      AvgPooling3dGradOp<D##Device, T>);
  REGISTER_KERNEL_BUILDER(Name("AvgPool3DGrad")                           \
                              .Device(DEVICE_##D)                         \
                              .TypeConstraint<T>("T")                     \
                              .HostMemory("orig_input_shape"),            \
                          AvgPooling3dGradOp<D##Device, T>);

#define REGISTER_CPU_KERNELS(T) REGISTER_KERNELS(CPU, T)
TF_CALL_float(REGISTER_CPU_KERNELS);
@@ -835,8 +832,7 @@ struct LaunchMaxPooling3dGradGradOp<GPUDevice, T> {
};

#define REGISTER_GPU_KERNELS(T) REGISTER_KERNELS(GPU, T)
TF_CALL_float(REGISTER_GPU_KERNELS)
TF_CALL_half(REGISTER_GPU_KERNELS)
TF_CALL_float(REGISTER_GPU_KERNELS) TF_CALL_half(REGISTER_GPU_KERNELS)
#undef REGISTER_GPU_KERNELS

#endif  // GOOGLE_CUDA
@@ -17,8 +17,8 @@ limitations under the License.

#define EIGEN_USE_GPU

#include "tensorflow/core/kernels/pooling_ops_3d_gpu.h"
#include "tensorflow/core/framework/register_types.h"
#include "tensorflow/core/kernels/pooling_ops_3d_gpu.h"
#include "tensorflow/core/util/cuda_kernel_helper.h"
#include "tensorflow/core/util/tensor_format.h"

@@ -159,12 +159,11 @@ bool MaxPool3dGradBackward<T>::operator()(
                                bottom_diff);
  }
  return d.ok();
};
}

}  // namespace functor

#define DEFINE_GPU_SPECS(T) \
  template struct functor::MaxPool3dGradBackward<T>;
#define DEFINE_GPU_SPECS(T) template struct functor::MaxPool3dGradBackward<T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);
#undef DEFINE_GPU_SPECS

@@ -373,7 +373,7 @@ void DnnPoolingGradOp<T>::Compute(
  }
}

#define DEFINE_DNN_OPS(T)         \
#define DEFINE_DNN_OPS(T)                      \
  template class DnnPoolingOp<T>; \
  template class DnnPoolingGradOp<T>;
TF_CALL_GPU_NUMBER_TYPES(DEFINE_DNN_OPS)

@@ -35,9 +35,9 @@ void dummy_xsmm_conv2d_ensure_file_is_not_empty(void);
#include "tensorflow/core/lib/core/blocking_counter.h"
#include "tensorflow/core/lib/core/threadpool.h"

#include "libxsmm_main.h"  // TODO(bsteiner): API to avoid incl. header from src/
#include "include/libxsmm_cpuid.h"
#include "include/libxsmm_malloc.h"
#include "libxsmm_main.h"  // TODO: API to avoid incl. header from src/

namespace tensorflow {

@@ -72,7 +72,6 @@ bool CanUseXsmmConv2D(const libxsmm_dnn_conv_desc& desc,
  return true;
}


typedef Eigen::ThreadPoolDevice CPUDevice;

namespace functor {
@@ -83,25 +82,34 @@ static void chk_libxsmm_err(libxsmm_dnn_err_t status, string msg) {
  }
}

LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, int S, int C, int K,int blocksifm, int blocksofm, int ifmblock,int ofmblock, int start, int end)
{
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C,K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm,R,S,ifmblock, ofmblock);
  int r, s, k,c, v1,v2;

  for (k = start; k < end ; k++ ) {
    for(c = 0; c < blocksifm;c++){
      for ( r = 0; r < R; r++ ) {
        for ( s = 0; s < S; s++ ){
          for ( v1 = c*ifmblock; v1 < std::min(C,(c+1)*ifmblock) ; v1++ ) {
            for ( v2 = k*ofmblock; v2 < std::min(K, (k+1)*ofmblock); v2++ )
              LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for ( v2 = K; v2 < (k+1)*ofmblock ; v2++ )
              LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f;
          }
          for ( v1 = C; v1 < (c+1)*ifmblock ; v1++ ) {
            for ( v2 = k*ofmblock; v2 < (k+1)*ofmblock; v2++ )
              LIBXSMM_VLA_ACCESS(6, output, k,c, r, s,v1- c*ifmblock,v2-k*ofmblock, blocksifm, R, S,ifmblock,ofmblock) = 0.0f;
LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float* kcrs, int R,
                                        int S, int C, int K, int blocksifm,
                                        int blocksofm, int ifmblock,
                                        int ofmblock, int start, int end) {
  LIBXSMM_VLA_DECL(4, const float, input, rsck, S, C, K);
  LIBXSMM_VLA_DECL(6, float, output, kcrs, blocksifm, R, S, ifmblock, ofmblock);
  int r, s, k, c, v1, v2;

  for (k = start; k < end; k++) {
    for (c = 0; c < blocksifm; c++) {
      for (r = 0; r < R; r++) {
        for (s = 0; s < S; s++) {
          for (v1 = c * ifmblock; v1 < std::min(C, (c + 1) * ifmblock); v1++) {
            for (v2 = k * ofmblock; v2 < std::min(K, (k + 1) * ofmblock); v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) =
                  LIBXSMM_VLA_ACCESS(4, input, r, s, v1, v2, S, C, K);
            for (v2 = K; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
          for (v1 = C; v1 < (c + 1) * ifmblock; v1++) {
            for (v2 = k * ofmblock; v2 < (k + 1) * ofmblock; v2++)
              LIBXSMM_VLA_ACCESS(6, output, k, c, r, s, v1 - c * ifmblock,
                                 v2 - k * ofmblock, blocksifm, R, S, ifmblock,
                                 ofmblock) = 0.0f;
          }
        }
      }
@@ -109,47 +117,28 @@ LIBXSMM_INLINE void copy_RSCK_to_custom(const float* rsck, float *kcrs, int R, i
  }
}


class libxsmm_dnn_conv_desc_wrap {
 public:
  const libxsmm_dnn_conv_desc d;

class libxsmm_dnn_conv_desc_wrap{
  public:
    const libxsmm_dnn_conv_desc d;

    libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc &d_) : d(d_){
    }
    bool operator==(const libxsmm_dnn_conv_desc_wrap &w) const{
      return( d.N == w.d.N &&
              d.C == w.d.C &&
              d.H == w.d.H &&
              d.W == w.d.W &&
              d.K == w.d.K &&
              d.R == w.d.R &&
              d.S == w.d.S &&
              d.u == w.d.u &&
              d.v == w.d.v &&
              d.pad_h == w.d.pad_h &&
              d.pad_w == w.d.pad_w
            );
    }
  libxsmm_dnn_conv_desc_wrap(const libxsmm_dnn_conv_desc& d_) : d(d_) {}
  bool operator==(const libxsmm_dnn_conv_desc_wrap& w) const {
    return (d.N == w.d.N && d.C == w.d.C && d.H == w.d.H && d.W == w.d.W &&
            d.K == w.d.K && d.R == w.d.R && d.S == w.d.S && d.u == w.d.u &&
            d.v == w.d.v && d.pad_h == w.d.pad_h && d.pad_w == w.d.pad_w);
  }
};


struct HashFunction{
  std::size_t operator()(const libxsmm_dnn_conv_desc_wrap & w) const{



struct HashFunction {
  std::size_t operator()(const libxsmm_dnn_conv_desc_wrap& w) const {
    // unsigned char ptr[sizeof(&w.d)];

    //unsigned char ptr[sizeof(&w.d)];


    //memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d))

    // memcpy(ptr, (unsigned char *)&w.d, sizeof(&w.d))

    //
    /*
    std::ostringstream N,C,H,W,K,R,S,u,v,padh,padw;

    N << w.d.N; C << w.d.C;
    H << w.d.H; W << w.d.W;
    K << w.d.K; R << w.d.R;
@@ -167,47 +156,53 @@ struct HashFunction{
    //
    //
    */
    return ( std::hash<unsigned long long>()((unsigned long long)&(w.d)));
    return (std::hash<unsigned long long>()((unsigned long long)&(w.d)));
  }
};
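
// Note that this hash keys on the *address* of the wrapped descriptor, so two
// equal descriptors stored at different addresses land in different buckets;
// the commented-out fragments above hint at field-based alternatives. A
// field-wise hash consistent with operator== could look like this (a sketch
// only, assuming the compared members are integral):
//
//   std::size_t h = 0;
//   for (int v : {w.d.N, w.d.C, w.d.H, w.d.W, w.d.K, w.d.R, w.d.S, w.d.u,
//                 w.d.v, w.d.pad_h, w.d.pad_w})
//     h = h * 31 + std::hash<int>()(v);
//   return h;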

class handles{
  public:
    libxsmm_dnn_layer* find( const libxsmm_dnn_conv_desc_wrap &w) {
      std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction>::iterator i = libxsmm_handles.find(w);
      if (i == libxsmm_handles.end()){
        libxsmm_dnn_err_t status;
        libxsmm_dnn_layer* libxsmm_handle = libxsmm_dnn_create_conv_layer(w.d, &status);
        chk_libxsmm_err(status, "Create handle");
        libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
        return libxsmm_handle;
      }
      else
        return i->second;

class handles {
 public:
  libxsmm_dnn_layer* find(const libxsmm_dnn_conv_desc_wrap& w) {
    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                       HashFunction>::iterator i = libxsmm_handles.find(w);
    if (i == libxsmm_handles.end()) {
      libxsmm_dnn_err_t status;
      libxsmm_dnn_layer* libxsmm_handle =
          libxsmm_dnn_create_conv_layer(w.d, &status);
      chk_libxsmm_err(status, "Create handle");
      libxsmm_handles.insert(std::make_pair(w, libxsmm_handle));
      return libxsmm_handle;
    } else {
      return i->second;
    }
    ~handles(){
      std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction>::iterator i;
      for (i= libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
  }
  ~handles() {
    std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                       HashFunction>::iterator i;
    for (i = libxsmm_handles.begin(); i != libxsmm_handles.end(); i++)
      chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(i->second),
                      "Destroy handle");
  }
  private:

    std::unordered_map<libxsmm_dnn_conv_desc_wrap , libxsmm_dnn_layer*, HashFunction> libxsmm_handles;

                      "Destroy handle");
  }

 private:
  std::unordered_map<libxsmm_dnn_conv_desc_wrap, libxsmm_dnn_layer*,
                     HashFunction>
      libxsmm_handles;
};

static handles libxsmm_handles;
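
// Usage: CallLibxsmmConvGeneric() below resolves a compiled libxsmm layer for
// a given descriptor through this process-wide cache, i.e.
//
//   libxsmm_dnn_conv_desc_wrap w(desc);
//   libxsmm_dnn_layer* libxsmm_handle = libxsmm_handles.find(w);
//
// so repeated convolutions with identical shapes reuse one handle instead of
// recreating the layer on every call.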

//#define LIBXSMM_DETAILED_TIMING
// #define LIBXSMM_DETAILED_TIMING

template <typename InputPtr, typename FilterPtr, typename OutputPtr>
static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
                                   const libxsmm_dnn_conv_desc& desc,
                                   libxsmm_dnn_compute_kind kind, InputPtr input,
                                   FilterPtr filter, OutputPtr output) {
                                   libxsmm_dnn_compute_kind kind,
                                   InputPtr input, FilterPtr filter,
                                   OutputPtr output) {
#if defined(LIBXSMM_DETAILED_TIMING)
  unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6, l_tick7, l_tick8, l_tick9, l_tick10;
  unsigned long long l_tick1, l_tick2, l_tick3, l_tick4, l_tick5, l_tick6,
      l_tick7, l_tick8, l_tick9, l_tick10;
  l_tick1 = libxsmm_timer_tick();
#endif
  // setup scoped allocator, which adopts the allocator from the context
@@ -216,14 +211,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
  libxsmm_dnn_layer* libxsmm_handle;
  libxsmm_dnn_conv_desc_wrap w(desc);
  void* scratch;

  //if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)

  // if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD)
  libxsmm_handle = libxsmm_handles.find(w);
  //else{
  // else{
  //  libxsmm_handle = libxsmm_dnn_create_conv_layer(desc, &status);
  //  chk_libxsmm_err(status, "Create handle");
  //}


  status = libxsmm_dnn_get_codegen_success(libxsmm_handle, kind);
  if (status == LIBXSMM_DNN_WARN_FALLBACK) {
    chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
@@ -241,12 +236,16 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif

  int ifmblock = (libxsmm_handle->ifmblock);
  int ofmblock = (libxsmm_handle->ofmblock);
  int ofmblock = (libxsmm_handle->ofmblock);

  int blocksifm = desc.C%ifmblock ==0 ? desc.C/ifmblock :desc.C/ifmblock + 1;
  int blocksofm = desc.K%ofmblock ==0 ? desc.K/ofmblock :desc.K/ofmblock + 1;
  float *native_filter = (float*)libxsmm_aligned_scratch( blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float), 2097152);

  int blocksifm =
      desc.C % ifmblock == 0 ? desc.C / ifmblock : desc.C / ifmblock + 1;
  int blocksofm =
      desc.K % ofmblock == 0 ? desc.K / ofmblock : desc.K / ofmblock + 1;
  float* native_filter =
      (float*)libxsmm_aligned_scratch(blocksofm * blocksifm * desc.R * desc.S *
                                          ifmblock * ofmblock * sizeof(float),
                                      2097152);
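
  // The two block-count expressions above are ceiling divisions: the channel
  // counts are rounded up to whole blocks, and copy_RSCK_to_custom() zero-fills
  // the padded tail of each partial block. An equivalent form, assuming
  // positive operands, is:
  //
  //   int blocksifm = (desc.C + ifmblock - 1) / ifmblock;
  //   int blocksofm = (desc.K + ofmblock - 1) / ofmblock;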

  const DeviceBase::CpuWorkerThreads* worker_threads =
      ctx->device()->tensorflow_cpu_worker_threads();
@@ -254,90 +253,111 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
  int num_threads = worker_threads->num_threads;

#if 1
  if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || kind == LIBXSMM_DNN_COMPUTE_KIND_BWD){
    if(blocksofm > num_threads){
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    if (blocksofm > num_threads) {
      int work = blocksofm;
      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int start = work/num_threads*i;
          int end = (start + work/num_threads) > work ? work: start + work/num_threads;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock,start, end);
          count.DecrementCount();
          int start = work / num_threads * i;
          int end = (start + work / num_threads) > work
                        ? work
                        : start + work / num_threads;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
    else{

    } else {
      int work = blocksofm;
      int num_threads = work;

      BlockingCounter count(num_threads);
      for (int i = 0; i < num_threads; ++i) {
        worker_threads->workers->Schedule([=, &count]() {
          int start = i;
          int end = i+1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S,desc.C, desc.K,blocksifm,blocksofm,ifmblock,ofmblock, start, end);
          count.DecrementCount();
          int start = i;
          int end = i + 1;
          copy_RSCK_to_custom(filter, native_filter, desc.R, desc.S, desc.C,
                              desc.K, blocksifm, blocksofm, ifmblock, ofmblock,
                              start, end);
          count.DecrementCount();
        });
      }
      count.Wait();
    }
  }
  //Added: for weight update
  else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD){
    libxsmm_filter = libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter, LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
    chk_libxsmm_err(status, "Link filter");//weight update is in RSCK as filter should be returned in RSCK format
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    // Added: for weight update
    libxsmm_filter =
        libxsmm_dnn_link_filter(libxsmm_handle, LIBXSMM_DNN_FILTER, filter,
                                LIBXSMM_DNN_TENSOR_FORMAT_RSCK_PTR, &status);
    chk_libxsmm_err(status,
                    "Link filter");  // weight update is in RSCK as
                                     // filter should be returned in RSCK
                                     // format
  }
#else
  memset( native_filter, 0, blocksofm*blocksifm*desc.R*desc.S*ifmblock*ofmblock*sizeof(float));
  memset(native_filter, 0,
         blocksofm * blocksifm * desc.R * desc.S * ifmblock * ofmblock *
             sizeof(float));
#endif

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick3 = libxsmm_timer_tick();
#endif

  libxsmm_input = libxsmm_dnn_link_buffer(
      libxsmm_handle, LIBXSMM_DNN_INPUT, input, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  libxsmm_input =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_INPUT, input,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link input buffer");
  libxsmm_output = libxsmm_dnn_link_buffer(
      libxsmm_handle, LIBXSMM_DNN_OUTPUT, output, LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  libxsmm_output =
      libxsmm_dnn_link_buffer(libxsmm_handle, LIBXSMM_DNN_OUTPUT, output,
                              LIBXSMM_DNN_TENSOR_FORMAT_NHWC_PTR, &status);
  chk_libxsmm_err(status, "Link output buffer");
  if(kind == LIBXSMM_DNN_COMPUTE_KIND_FWD || kind == LIBXSMM_DNN_COMPUTE_KIND_BWD){
    libxsmm_filter = libxsmm_dnn_link_filter(
        libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter, LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
    chk_libxsmm_err(status, "Link filter");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD ||
      kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    libxsmm_filter = libxsmm_dnn_link_filter(
        libxsmm_handle, LIBXSMM_DNN_FILTER, native_filter,
        LIBXSMM_DNN_TENSOR_FORMAT_LIBXSMM_PTR, &status);
    chk_libxsmm_err(status, "Link filter");
  }
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_output), "Zero output");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input forward");
    chk_libxsmm_err(
        libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_REGULAR_OUTPUT),
        "Bind output forward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_REGULAR_OUTPUT),
                    "Bind output forward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter forward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    chk_libxsmm_err(libxsmm_dnn_zero_buffer(libxsmm_input), "Zero input");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_GRADIENT_INPUT),
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_GRADIENT_INPUT),
                    "Bind input backward");
    chk_libxsmm_err(
        libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "Bind output backward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_REGULAR_FILTER),
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output backward");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_REGULAR_FILTER),
                    "Bind filter backward");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    chk_libxsmm_err(libxsmm_dnn_zero_filter(libxsmm_filter), "Zero filter");

    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input, LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input weight udpate");
    chk_libxsmm_err(
        libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "Bind output weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter, LIBXSMM_DNN_GRADIENT_FILTER),
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_input,
                                            LIBXSMM_DNN_REGULAR_INPUT),
                    "Bind input weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_buffer(libxsmm_handle, libxsmm_output,
                                            LIBXSMM_DNN_GRADIENT_OUTPUT),
                    "Bind output weight update");
    chk_libxsmm_err(libxsmm_dnn_bind_filter(libxsmm_handle, libxsmm_filter,
                                            LIBXSMM_DNN_GRADIENT_FILTER),
                    "Bind filter weight update");
  } else {
    /* shouldn't happen */
@@ -348,9 +368,14 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif

  /* bind scratch */
  scratch = (void*)libxsmm_aligned_scratch( libxsmm_dnn_get_scratch_size( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, &status ), 2097152);
  chk_libxsmm_err( status, "scratch allocation" );
  chk_libxsmm_err( libxsmm_dnn_bind_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch ), "binding scratch" );
  scratch = (void*)libxsmm_aligned_scratch(
      libxsmm_dnn_get_scratch_size(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL,
                                   &status),
      2097152);
  chk_libxsmm_err(status, "scratch allocation");
  chk_libxsmm_err(libxsmm_dnn_bind_scratch(
                      libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL, scratch),
                  "binding scratch");

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick5 = libxsmm_timer_tick();
@@ -366,7 +391,7 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,

#if 1
  BlockingCounter counter(num_threads);

  for (int i = 0; i < num_threads; ++i) {
    worker_threads->workers->Schedule([=, &counter]() {
      chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, i),
@@ -376,9 +401,11 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
  }
  counter.Wait();
#else
#pragma omp parallel
#pragma omp parallel
  {
    chk_libxsmm_err(libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()), "Worker");
    chk_libxsmm_err(
        libxsmm_dnn_execute_st(libxsmm_handle, kind, 0, omp_get_thread_num()),
        "Worker");
  }
#endif

@@ -387,7 +414,7 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif

  if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    libxsmm_dnn_reduce_wu_filters( libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER );
    libxsmm_dnn_reduce_wu_filters(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER);
  }

#if defined(LIBXSMM_DETAILED_TIMING)
@@ -395,19 +422,39 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#endif

  /* clean up */
  chk_libxsmm_err( libxsmm_dnn_release_scratch( libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL ), "release scratch" );
  chk_libxsmm_err(
      libxsmm_dnn_release_scratch(libxsmm_handle, LIBXSMM_DNN_COMPUTE_KIND_ALL),
      "release scratch");
  if (kind == LIBXSMM_DNN_COMPUTE_KIND_FWD) {
    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" );
    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT ), "release output" );
    chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_BWD) {
    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT ), "release input" );
    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" );
    chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER ), "release filter" );
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_REGULAR_FILTER),
        "release filter");
  } else if (kind == LIBXSMM_DNN_COMPUTE_KIND_UPD) {
    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT ), "release input" );
    chk_libxsmm_err( libxsmm_dnn_release_buffer( libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT ), "release output" );
    chk_libxsmm_err( libxsmm_dnn_release_filter( libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER ), "release filter" );
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_REGULAR_INPUT),
        "release input");
    chk_libxsmm_err(
        libxsmm_dnn_release_buffer(libxsmm_handle, LIBXSMM_DNN_GRADIENT_OUTPUT),
        "release output");
    chk_libxsmm_err(
        libxsmm_dnn_release_filter(libxsmm_handle, LIBXSMM_DNN_GRADIENT_FILTER),
        "release filter");
  } else {
    /* shouldn't happen */
  }
@@ -418,9 +465,9 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,
#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick9 = libxsmm_timer_tick();
#endif

  //if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
  //chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),

  // if(kind != LIBXSMM_DNN_COMPUTE_KIND_FWD)
  //  chk_libxsmm_err(libxsmm_dnn_destroy_conv_layer(libxsmm_handle),
  //                  "Destroy handle");

  libxsmm_free(native_filter);
@@ -428,17 +475,20 @@ static bool CallLibxsmmConvGeneric(OpKernelContext* ctx,

#if defined(LIBXSMM_DETAILED_TIMING)
  l_tick10 = libxsmm_timer_tick();
  printf("time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, %f, %f, %f\n", desc.N, desc.C, desc.K, desc.R, desc.S,
         libxsmm_timer_duration(l_tick1, l_tick2),
         libxsmm_timer_duration(l_tick2, l_tick3),
         libxsmm_timer_duration(l_tick3, l_tick4),
         libxsmm_timer_duration(l_tick4, l_tick5),
         libxsmm_timer_duration(l_tick5, l_tick6),
         libxsmm_timer_duration(l_tick6, l_tick7),
         libxsmm_timer_duration(l_tick7, l_tick8),
         libxsmm_timer_duration(l_tick8, l_tick9),
         libxsmm_timer_duration(l_tick9, l_tick10),
         libxsmm_timer_duration(l_tick1, l_tick10) );
  printf(
      "time for convolution (%i, %i, %i, %i, %i): %f, %f, %f, %f, %f, %f, %f, "
      "%f, %f, %f\n",
      desc.N, desc.C, desc.K, desc.R, desc.S,
      libxsmm_timer_duration(l_tick1, l_tick2),
      libxsmm_timer_duration(l_tick2, l_tick3),
      libxsmm_timer_duration(l_tick3, l_tick4),
      libxsmm_timer_duration(l_tick4, l_tick5),
      libxsmm_timer_duration(l_tick5, l_tick6),
      libxsmm_timer_duration(l_tick6, l_tick7),
      libxsmm_timer_duration(l_tick7, l_tick8),
      libxsmm_timer_duration(l_tick8, l_tick9),
      libxsmm_timer_duration(l_tick9, l_tick10),
      libxsmm_timer_duration(l_tick1, l_tick10));
#endif

  return true;  // Succeeded
@@ -448,8 +498,8 @@ template <typename T>
struct XsmmFwdConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, const T* filter, T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD, input,
                                  filter, output);
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_FWD,
                                  input, filter, output);
  }
};

@@ -457,8 +507,8 @@ template <typename T>
struct XsmmBkwInputConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  T* input, const T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD, input,
                                  filter, output);
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_BWD,
                                  input, filter, output);
  }
};

@@ -466,8 +516,8 @@ template <typename T>
struct XsmmBkwFilterConv2D<CPUDevice, T> {
  bool operator()(OpKernelContext* ctx, const libxsmm_dnn_conv_desc& desc,
                  const T* input, T* filter, const T* output) {
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD, input,
                                  filter, output);
    return CallLibxsmmConvGeneric(ctx, desc, LIBXSMM_DNN_COMPUTE_KIND_UPD,
                                  input, filter, output);
  }
};
@@ -47,7 +47,7 @@ Status InputBuffer::ReadLine(string* result) {
  Status s;
  do {
    size_t buf_remain = limit_ - pos_;
    char* newline = (char*)memchr(pos_, '\n', buf_remain);
    char* newline = static_cast<char*>(memchr(pos_, '\n', buf_remain));
    if (newline != nullptr) {
      size_t result_len = newline - pos_;
      result->append(pos_, result_len);
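
// The change above only swaps a C-style cast for static_cast; both forms scan
// the buffered window for the next '\n' via memchr. A minimal standalone
// sketch of the same scan (buffer names here are hypothetical):
//
//   const char* pos = buffer_start;   // current read position
//   size_t remain = buffer_end - pos; // bytes left in the buffered window
//   const char* nl = static_cast<const char*>(memchr(pos, '\n', remain));
//   if (nl != nullptr) {
//     size_t line_len = nl - pos;     // a full line is available in-buffer
//   }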
@@ -17,6 +17,7 @@ limitations under the License.

#include <string.h>
#include <sys/types.h>
#include <zlib.h>
#include <string>
#include <utility>
#include <vector>
@@ -152,7 +153,10 @@ bool DecodeHeader(StringPiece png_string, int* width, int* height,
  if (components != NULL) {
    switch (context.color_type) {
      case PNG_COLOR_TYPE_PALETTE:
        *components = (context.info_ptr->valid & PNG_INFO_tRNS) ? 4 : 3;
        *components =
            (png_get_valid(context.png_ptr, context.info_ptr, PNG_INFO_tRNS))
                ? 4
                : 3;
        break;
      case PNG_COLOR_TYPE_GRAY:
        *components = 1;
@@ -176,8 +180,11 @@ bool DecodeHeader(StringPiece png_string, int* width, int* height,
  }
  if (metadata != NULL) {
    metadata->clear();
    for (int i = 0; i < context.info_ptr->num_text; i++) {
      const png_text& text = context.info_ptr->text[i];
    png_textp text_ptr = NULL;
    int num_text = 0;
    png_get_text(context.png_ptr, context.info_ptr, &text_ptr, &num_text);
    for (int i = 0; i < num_text; i++) {
      const png_text& text = text_ptr[i];
      metadata->push_back(std::make_pair(text.key, text.text));
    }
  }
@@ -228,9 +235,10 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels,
    return false;
  }
  if (context->channels == 0) {  // Autodetect number of channels
    context->channels = context->info_ptr->channels;
    context->channels = png_get_channels(context->png_ptr, context->info_ptr);
  }
  const bool has_tRNS = (context->info_ptr->valid & PNG_INFO_tRNS) != 0;
  const bool has_tRNS =
      (png_get_valid(context->png_ptr, context->info_ptr, PNG_INFO_tRNS)) != 0;
  const bool has_alpha = (context->color_type & PNG_COLOR_MASK_ALPHA) != 0;
  if ((context->channels & 1) == 0) {  // We desire alpha
    if (has_alpha) {  // There is alpha
@@ -268,7 +276,9 @@ bool CommonInitDecode(StringPiece png_string, int desired_channels,
  const bool want_gray = (context->channels < 3);
  const bool is_gray = !(context->color_type & PNG_COLOR_MASK_COLOR);
  if (is_gray) {  // upconvert gray to 8-bit if needed.
    if (context->bit_depth < 8) png_set_gray_1_2_4_to_8(context->png_ptr);
    if (context->bit_depth < 8) {
      png_set_expand_gray_1_2_4_to_8(context->png_ptr);
    }
  }
  if (want_gray) {  // output is grayscale
    if (!is_gray)
@@ -301,7 +311,9 @@ bool CommonFinishDecode(png_bytep data, int row_bytes, DecodeContext* context) {
    }
  }

  context->info_ptr->valid |= PNG_INFO_IDAT;
  // Marks iDAT as valid.
  png_set_rows(context->png_ptr, context->info_ptr,
               png_get_rows(context->png_ptr, context->info_ptr));
  png_read_end(context->png_ptr, context->info_ptr);

  // Clean up.
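
// The libpng changes above all follow one pattern: direct reads of png_info
// internals (info_ptr->valid, info_ptr->channels, info_ptr->num_text) are
// replaced with the public accessors png_get_valid, png_get_channels,
// png_get_text, and png_set_rows, so the code keeps working when libpng makes
// its structs opaque. For example, the tRNS presence check becomes:
//
//   if (png_get_valid(png_ptr, info_ptr, PNG_INFO_tRNS)) {
//     // palette image carries transparency; decode to four channels
//   }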
@@ -1782,6 +1782,69 @@ op {
}
}
}
op {
name: "AvgPool"
input_arg {
name: "value"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
}
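The `padding` attr above is the usual SAME/VALID choice. As a reminder of what the two values mean for output extent, here is a standalone sketch (not TensorFlow code) of the ceiling-division formulas:

#include <cstdio>

// Output extent of a pooling/conv window along one dimension.
// VALID: only windows fully inside the input; SAME: input padded so the
// output is ceil(input / stride).
int OutputSize(int input, int window, int stride, bool same_padding) {
  if (same_padding) {
    return (input + stride - 1) / stride;  // ceil(input / stride)
  }
  return (input - window) / stride + 1;    // assumes input >= window
}

int main() {
  // 10 wide, 2x2 window, stride 3.
  std::printf("SAME: %d, VALID: %d\n",
              OutputSize(10, 2, 3, true), OutputSize(10, 2, 3, false));
  // Prints: SAME: 4, VALID: 3
}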
op {
name: "AvgPool3D"
input_arg {
@@ -2097,6 +2160,73 @@ op {
}
}
}
op {
name: "AvgPoolGrad"
input_arg {
name: "orig_input_shape"
type: DT_INT32
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
}
op {
name: "Barrier"
output_arg {
@@ -9755,6 +9885,72 @@ op {
}
}
}
op {
name: "MaxPool"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
}
op {
name: "MaxPool3D"
input_arg {
@@ -10017,6 +10213,181 @@ op {
}
}
}
op {
name: "MaxPool3DGrad"
input_arg {
name: "orig_input"
type_attr: "TInput"
}
input_arg {
name: "orig_output"
type_attr: "TInput"
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 5
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 5
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NDHWC"
}
allowed_values {
list {
s: "NDHWC"
s: "NCDHW"
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "TInput"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
}
op {
name: "MaxPool3DGradGrad"
input_arg {
name: "orig_input"
type_attr: "T"
}
input_arg {
name: "orig_output"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 5
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 5
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NDHWC"
}
allowed_values {
list {
s: "NDHWC"
s: "NCDHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
}
op {
name: "MaxPoolGrad"
input_arg {
@@ -10084,6 +10455,219 @@ op {
}
}
}
op {
name: "MaxPoolGrad"
input_arg {
name: "orig_input"
type_attr: "T"
}
input_arg {
name: "orig_output"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
}
op {
name: "MaxPoolGradGrad"
input_arg {
name: "orig_input"
type_attr: "T"
}
input_arg {
name: "orig_output"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
}
op {
name: "MaxPoolGradGradWithArgmax"
input_arg {
name: "input"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
input_arg {
name: "argmax"
type_attr: "Targmax"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "Targmax"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
}
op {
name: "MaxPoolGradWithArgmax"
input_arg {
@@ -10148,6 +10732,137 @@ op {
}
}
}
op {
name: "MaxPoolGradWithArgmax"
input_arg {
name: "input"
type_attr: "T"
}
input_arg {
name: "grad"
type_attr: "T"
}
input_arg {
name: "argmax"
type_attr: "Targmax"
}
output_arg {
name: "output"
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "Targmax"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
}
op {
name: "MaxPoolWithArgmax"
input_arg {
name: "input"
type_attr: "T"
}
output_arg {
name: "output"
type_attr: "T"
}
output_arg {
name: "argmax"
type_attr: "Targmax"
}
attr {
name: "ksize"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
has_minimum: true
minimum: 4
}
attr {
name: "Targmax"
type: "type"
default_value {
type: DT_INT64
}
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "padding"
type: "string"
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
}
}
}
}
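The `argmax` output above stores flattened positions: per the op's documentation, the maximum at `[b, y, x, c]` is encoded as `((b * height + y) * width + x) * channels + c`. A small sketch of that encoding and its inverse (helper names are illustrative):

#include <cstdint>

// Flatten an NHWC coordinate the way MaxPoolWithArgmax reports argmax.
int64_t FlattenNHWC(int64_t b, int64_t y, int64_t x, int64_t c,
                    int64_t height, int64_t width, int64_t channels) {
  return ((b * height + y) * width + x) * channels + c;
}

// Recover the coordinate from a flattened argmax index.
void UnflattenNHWC(int64_t index, int64_t height, int64_t width,
                   int64_t channels, int64_t* b, int64_t* y, int64_t* x,
                   int64_t* c) {
  *c = index % channels;
  index /= channels;
  *x = index % width;
  index /= width;
  *y = index % height;
  *b = index / height;
}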
op {
name: "MaxPoolWithArgmax"
input_arg {
@@ -10200,12 +10915,16 @@ op {
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}

@@ -181,7 +181,6 @@ Status MaxPoolGrad(const AttrSlice& attrs, FunctionDef* g) {
}
REGISTER_OP_GRADIENT("MaxPool", MaxPoolGrad);

Status MaxPoolGradGrad(const AttrSlice& attrs, FunctionDef* g) {
// clang-format off
*g = FDH::Define(

@@ -2034,8 +2034,14 @@ op {
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
@@ -2259,8 +2265,14 @@ op {
allowed_values {
list {
type: DT_FLOAT
type: DT_HALF
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
@@ -8465,7 +8477,7 @@ op {
}
}
summary: "Gather slices from `params` according to `indices`."
description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n```python\n  # Scalar indices\n  output[:, ..., :] = params[indices, :, ... :]\n\n  # Vector indices\n  output[i, :, ..., :] = params[indices[i], :, ... :]\n\n  # Higher rank indices\n  output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n```\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/Gather.png\" alt>\n</div>"
description: "`indices` must be an integer tensor of any dimension (usually 0-D or 1-D).\nProduces an output tensor with shape `indices.shape + params.shape[1:]` where:\n\n```python\n  # Scalar indices\n  output[:, ..., :] = params[indices, :, ... :]\n\n  # Vector indices\n  output[i, :, ..., :] = params[indices[i], :, ... :]\n\n  # Higher rank indices\n  output[i, ..., j, :, ... :] = params[indices[i, ..., j], :, ..., :]\n```\n\nIf `indices` is a permutation and `len(indices) == params.shape[0]` then\nthis operation will permute `params` accordingly.\n\n`validate_indices`: DEPRECATED. If this operation is assigned to CPU, values in\n`indices` are always validated to be within range. If assigned to GPU,\nout-of-bound indices result in unspecified behavior (currently the result is\n`0`, but this may become an error in the future).\n\n<div style=\"width:70%; margin:auto; margin-bottom:10px; margin-top:20px;\">\n<img style=\"width:100%\" src=\"../../images/Gather.png\" alt>\n</div>"
}
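The updated Gather description spells out the shape rule `indices.shape + params.shape[1:]` and the CPU-side index validation. A toy C++ sketch of rank-1 gather over rows, with the range check the description promises on CPU (this is illustrative, not the kernel's implementation):

#include <stdexcept>
#include <vector>

// Toy gather: out[i] = params[indices[i]] for row vectors, mirroring
// output.shape = indices.shape + params.shape[1:] in the rank-1 case.
std::vector<std::vector<float>> Gather(
    const std::vector<std::vector<float>>& params,
    const std::vector<int>& indices) {
  std::vector<std::vector<float>> out;
  out.reserve(indices.size());
  for (int idx : indices) {
    if (idx < 0 || idx >= static_cast<int>(params.size())) {
      // On CPU, out-of-range indices are an error; on GPU the op's
      // behavior is unspecified (currently yields zeros).
      throw std::out_of_range("gather index out of range");
    }
    out.push_back(params[idx]);
  }
  return out;
}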
op {
name: "GatherNd"
@@ -10574,6 +10586,13 @@ op {
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
@@ -10699,12 +10718,12 @@ op {
input_arg {
name: "orig_input"
description: "The original input tensor."
type: DT_FLOAT
type_attr: "TInput"
}
input_arg {
name: "orig_output"
description: "The original output tensor."
type: DT_FLOAT
type_attr: "TInput"
}
input_arg {
name: "grad"
@@ -10757,6 +10776,34 @@ op {
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT64
type: DT_INT32
type: DT_UINT8
type: DT_UINT16
type: DT_INT16
type: DT_INT8
type: DT_COMPLEX64
type: DT_COMPLEX128
type: DT_QINT8
type: DT_QUINT8
type: DT_QINT32
type: DT_HALF
}
}
}
attr {
name: "TInput"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
@@ -10778,6 +10825,86 @@ op {
}
summary: "Computes gradients of max pooling function."
}
op {
name: "MaxPool3DGradGrad"
input_arg {
name: "orig_input"
description: "The original input tensor."
type_attr: "T"
}
input_arg {
name: "orig_output"
description: "The original output tensor."
type_attr: "T"
}
input_arg {
name: "grad"
description: "Output backprop of shape `[batch, depth, rows, cols, channels]`."
type_attr: "T"
}
output_arg {
name: "output"
description: "Gradients of gradients w.r.t. the input to `max_pool`."
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
description: "1-D tensor of length 5. The size of the window for each dimension of\nthe input tensor. Must have `ksize[0] = ksize[4] = 1`."
has_minimum: true
minimum: 5
}
attr {
name: "strides"
type: "list(int)"
description: "1-D tensor of length 5. The stride of the sliding window for each\ndimension of `input`. Must have `strides[0] = strides[4] = 1`."
has_minimum: true
minimum: 5
}
attr {
name: "padding"
type: "string"
description: "The type of padding algorithm to use."
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NDHWC"
}
description: "The data format of the input and output data. With the\ndefault format \"NDHWC\", the data is stored in the order of:\n    [batch, in_depth, in_height, in_width, in_channels].\nAlternatively, the format could be \"NCDHW\", the data storage order is:\n    [batch, in_channels, in_depth, in_height, in_width]."
allowed_values {
list {
s: "NDHWC"
s: "NCDHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
summary: "Computes second-order gradients of the maxpooling function."
}
op {
name: "MaxPoolGrad"
input_arg {
@@ -10848,12 +10975,175 @@ op {
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
summary: "Computes gradients of the maxpooling function."
}
op {
name: "MaxPoolGradGrad"
input_arg {
name: "orig_input"
description: "The original input tensor."
type_attr: "T"
}
input_arg {
name: "orig_output"
description: "The original output tensor."
type_attr: "T"
}
input_arg {
name: "grad"
description: "4-D. Gradients of gradients w.r.t. the input of `max_pool`."
type_attr: "T"
}
output_arg {
name: "output"
description: "Gradients of gradients w.r.t. the input to `max_pool`."
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
description: "The size of the window for each dimension of the input tensor."
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
description: "The stride of the sliding window for each dimension of the\ninput tensor."
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
description: "The type of padding algorithm to use."
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "data_format"
type: "string"
default_value {
s: "NHWC"
}
description: "Specify the data format of the input and output data. With the\ndefault format \"NHWC\", the data is stored in the order of:\n    [batch, in_height, in_width, in_channels].\nAlternatively, the format could be \"NCHW\", the data storage order of:\n    [batch, in_channels, in_height, in_width]."
allowed_values {
list {
s: "NHWC"
s: "NCHW"
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
summary: "Computes second-order gradients of the maxpooling function."
}
op {
name: "MaxPoolGradGradWithArgmax"
input_arg {
name: "input"
description: "The original input."
type_attr: "T"
}
input_arg {
name: "grad"
description: "4-D with shape `[batch, height, width, channels]`. Gradients w.r.t. the\ninput of `max_pool`."
type_attr: "T"
}
input_arg {
name: "argmax"
description: "The indices of the maximum values chosen for each output of `max_pool`."
type_attr: "Targmax"
}
output_arg {
name: "output"
description: "Gradients of gradients w.r.t. the input of `max_pool`."
type_attr: "T"
}
attr {
name: "ksize"
type: "list(int)"
description: "The size of the window for each dimension of the input tensor."
has_minimum: true
minimum: 4
}
attr {
name: "strides"
type: "list(int)"
description: "The stride of the sliding window for each dimension of the\ninput tensor."
has_minimum: true
minimum: 4
}
attr {
name: "padding"
type: "string"
description: "The type of padding algorithm to use."
allowed_values {
list {
s: "SAME"
s: "VALID"
}
}
}
attr {
name: "Targmax"
type: "type"
allowed_values {
list {
type: DT_INT32
type: DT_INT64
}
}
}
attr {
name: "T"
type: "type"
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
}
summary: "Computes second-order gradients of the maxpooling function."
}
op {
name: "MaxPoolGradWithArgmax"
input_arg {
@@ -10914,12 +11204,16 @@ op {
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
@@ -10984,12 +11278,16 @@ op {
attr {
name: "T"
type: "type"
default_value {
type: DT_FLOAT
}
allowed_values {
list {
type: DT_FLOAT
type: DT_DOUBLE
type: DT_INT32
type: DT_INT64
type: DT_UINT8
type: DT_INT16
type: DT_INT8
type: DT_UINT16
type: DT_HALF
}
}
@@ -69,10 +69,6 @@ int GetXCR0EAX() {
// Structure for basic CPUID info
class CPUIDInfo {
public:
string vendor_str;
int family;
int model_num;

CPUIDInfo()
: have_adx_(0),
have_aes_(0),
@@ -121,9 +117,9 @@ public:

// Get vendor string (issue CPUID with eax = 0)
GETCPUID(eax, ebx, ecx, edx, 0, 0);
cpuid->vendor_str.append(reinterpret_cast<char *>(&ebx), 4);
cpuid->vendor_str.append(reinterpret_cast<char *>(&edx), 4);
cpuid->vendor_str.append(reinterpret_cast<char *>(&ecx), 4);
cpuid->vendor_str_.append(reinterpret_cast<char *>(&ebx), 4);
cpuid->vendor_str_.append(reinterpret_cast<char *>(&edx), 4);
cpuid->vendor_str_.append(reinterpret_cast<char *>(&ecx), 4);

// To get general information and extended features we send eax = 1 and
// ecx = 0 to cpuid. The response is returned in eax, ebx, ecx and edx.
@@ -131,8 +127,8 @@ public:
// Volume 2A: Instruction Set Reference, A-M CPUID).
GETCPUID(eax, ebx, ecx, edx, 1, 0);

cpuid->model_num = static_cast<int>((eax >> 4) & 0xf);
cpuid->family = static_cast<int>((eax >> 8) & 0xf);
cpuid->model_num_ = static_cast<int>((eax >> 4) & 0xf);
cpuid->family_ = static_cast<int>((eax >> 8) & 0xf);

cpuid->have_aes_ = (ecx >> 25) & 0x1;
cpuid->have_cmov_ = (edx >> 15) & 0x1;
@@ -254,6 +250,10 @@ public:
return false;
}

string vendor_str() const { return vendor_str_; }
int family() const { return family_; }
int model_num() { return model_num_; }

private:
int highest_eax_;
int have_adx_ : 1;
@@ -293,6 +293,9 @@ public:
int have_sse4_2_ : 1;
int have_ssse3_ : 1;
int have_hypervisor_ : 1;
string vendor_str_;
int family_;
int model_num_;
};

std::once_flag cpuid_once_flag;
@@ -318,7 +321,7 @@ bool TestCPUFeature(CPUFeature feature) {
std::string CPUVendorIDString() {
#ifdef PLATFORM_IS_X86
InitCPUIDInfo();
return cpuid->vendor_str;
return cpuid->vendor_str();
#else
return "";
#endif
@@ -327,7 +330,7 @@ std::string CPUVendorIDString() {
int CPUFamily() {
#ifdef PLATFORM_IS_X86
InitCPUIDInfo();
return cpuid->family;
return cpuid->family();
#else
return 0;
#endif
@@ -336,7 +339,7 @@ int CPUFamily() {
int CPUModelNum() {
#ifdef PLATFORM_IS_X86
InitCPUIDInfo();
return cpuid->model_num;
return cpuid->model_num();
#else
return 0;
#endif
@@ -58,21 +58,20 @@ int NumSchedulableCPUs() {

void* AlignedMalloc(size_t size, int minimum_alignment) {
#ifdef TENSORFLOW_USE_JEMALLOC
void* ptr = NULL;
// posix_memalign requires that the requested alignment be at least
// sizeof(void*). In this case, fall back on malloc which should return
// memory aligned to at least the size of a pointer.
const int required_alignment = sizeof(void*);
if (minimum_alignment < required_alignment) return Malloc(size);
int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
if (err != 0) {
return NULL;
}
else {
return ptr;
}
void* ptr = NULL;
// posix_memalign requires that the requested alignment be at least
// sizeof(void*). In this case, fall back on malloc which should return
// memory aligned to at least the size of a pointer.
const int required_alignment = sizeof(void*);
if (minimum_alignment < required_alignment) return Malloc(size);
int err = jemalloc_posix_memalign(&ptr, minimum_alignment, size);
if (err != 0) {
return NULL;
} else {
return ptr;
}
#else
return _aligned_malloc(size, minimum_alignment);
return _aligned_malloc(size, minimum_alignment);
#endif
}
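On platforms without jemalloc, the same contract can be met with plain posix_memalign. A minimal sketch of the fallback logic (the real code's `Malloc` wrapper is replaced here with std::malloc):

#include <cstdlib>

// posix_memalign rejects alignments below sizeof(void*), so small
// alignment requests fall back to malloc, which is already
// pointer-aligned. Returns nullptr on failure, like the code above.
void* AlignedAlloc(size_t size, size_t minimum_alignment) {
  if (minimum_alignment < sizeof(void*)) return std::malloc(size);
  void* ptr = nullptr;
  if (posix_memalign(&ptr, minimum_alignment, size) != 0) return nullptr;
  return ptr;
}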
@@ -9,4 +9,8 @@ option java_package = "org.tensorflow.framework";
message RewriterConfig {
bool optimize_tensor_layout = 1;
bool disable_model_pruning = 2;
bool constant_folding = 3;
// If non-empty, will use this as an alternative way to specify a list of
// optimizations to turn on and the order of the optimizations.
repeated string optimizers = 100;
}
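From C++ the new repeated field uses the standard generated protobuf mutators. A sketch, assuming the generated header path matches the proto's location; the optimizer names are placeholders, not values defined by this proto:

#include "tensorflow/core/protobuf/rewriter_config.pb.h"

void Configure(tensorflow::RewriterConfig* config) {
  // Booleans toggle individual passes; the repeated field, when
  // non-empty, both selects the passes and fixes their order.
  config->set_constant_folding(true);
  config->add_optimizers("constfold");  // placeholder pass name
  config->add_optimizers("memory");     // placeholder pass name
}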
@@ -28,13 +28,18 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
string* diff, const EqualGraphDefOptions& options) {
// Intentionally do not check that versions match so that this routine can
// be used for less brittle golden file tests.
return EqualRepeatedNodeDef(actual.node(), expected.node(), diff, options);
}

bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
const protobuf::RepeatedPtrField<NodeDef>& expected,
string* diff, const EqualGraphDefOptions& options) {
std::unordered_map<string, const NodeDef*> actual_index;
for (const NodeDef& node : actual.node()) {
for (const NodeDef& node : actual) {
actual_index[node.name()] = &node;
}

for (const NodeDef& expected_node : expected.node()) {
for (const NodeDef& expected_node : expected) {
auto actual_iter = actual_index.find(expected_node.name());
if (actual_iter == actual_index.end()) {
if (diff != nullptr) {
@@ -53,10 +58,9 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,

if (!actual_index.empty()) {
if (diff != nullptr) {
*diff = strings::StrCat("Found unexpected node '",
SummarizeNodeDef(*actual_index.begin()->second),
"' not in expected graph:\n",
SummarizeGraphDef(expected));
*diff =
strings::StrCat("Found unexpected node '",
SummarizeNodeDef(*actual_index.begin()->second), "'");
}
return false;
}
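The comparison is the usual order-insensitive set match: index one side by name, consume matches, and whatever remains is unexpected. A generic sketch of the same shape, detached from NodeDef (types and messages here are illustrative):

#include <string>
#include <unordered_map>
#include <vector>

struct Named { std::string name; std::string payload; };

// Order-insensitive equality over name-keyed records, mirroring the
// EqualRepeatedNodeDef structure: index, probe, erase, then check leftovers.
bool EqualIgnoringOrder(const std::vector<Named>& actual,
                        const std::vector<Named>& expected,
                        std::string* diff) {
  std::unordered_map<std::string, const Named*> actual_index;
  for (const Named& n : actual) actual_index[n.name] = &n;

  for (const Named& e : expected) {
    auto it = actual_index.find(e.name);
    if (it == actual_index.end()) {
      if (diff != nullptr) *diff = "Missing node '" + e.name + "'";
      return false;
    }
    if (it->second->payload != e.payload) {
      if (diff != nullptr) *diff = "Node '" + e.name + "' differs";
      return false;
    }
    actual_index.erase(it);
  }
  if (!actual_index.empty()) {
    if (diff != nullptr)
      *diff = "Found unexpected node '" + actual_index.begin()->first + "'";
    return false;
  }
  return true;
}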
@@ -18,6 +18,7 @@ limitations under the License.

#include "tensorflow/core/framework/graph.pb.h"
#include "tensorflow/core/framework/graph_def_util.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
@@ -44,6 +45,14 @@ bool EqualGraphDef(const GraphDef& actual, const GraphDef& expected,
bool EqualNodeDef(const NodeDef& actual, const NodeDef& expected, string* diff,
const EqualGraphDefOptions& options = {});

// Determines if actual and expected are equal, ignoring ordering. If they're
// different and diff != nullptr, *diff is set to an explanation of the
// difference.
bool EqualRepeatedNodeDef(const protobuf::RepeatedPtrField<NodeDef>& actual,
const protobuf::RepeatedPtrField<NodeDef>& expected,
string* diff,
const EqualGraphDefOptions& options = {});

#define TF_EXPECT_GRAPH_EQ(expected, actual) \
do { \
string diff; \
@@ -47,8 +47,7 @@ class EqualGraphDefTest : public ::testing::Test {
protected:
EqualGraphDefTest()
: e_(GraphDefBuilder::kFailImmediately),
a_(GraphDefBuilder::kFailImmediately) {
}
a_(GraphDefBuilder::kFailImmediately) {}

bool Match() {
GraphDef expected;
@@ -89,11 +88,7 @@ TEST_F(EqualGraphDefTest, ExtraNode) {
Input(a_.opts().WithName("A"));
Input(a_.opts().WithName("B"));
EXPECT_FALSE(Match());
EXPECT_EQ(strings::StrCat(
"Found unexpected node 'B = Input[]()' not in expected graph:\n"
"versions = producer: ",
TF_GRAPH_DEF_VERSION, ";\n", "A = Input[]();\n"),
diff_);
EXPECT_EQ("Found unexpected node 'B = Input[]()'", diff_);
}

TEST_F(EqualGraphDefTest, NodeOrder) {
@@ -169,21 +164,23 @@ TEST_F(EqualGraphDefTest, ControlInputOrder) {
Node* b = Input(e_.opts().WithName("B"));
Node* c = Input(e_.opts().WithName("C"));
Node* d = Input(e_.opts().WithName("D"));
Combine(a, a, e_.opts()
.WithName("E")
.WithControlInput(b)
.WithControlInput(c)
.WithControlInput(d));
Combine(a, a,
e_.opts()
.WithName("E")
.WithControlInput(b)
.WithControlInput(c)
.WithControlInput(d));

a = Input(a_.opts().WithName("A"));
b = Input(a_.opts().WithName("B"));
c = Input(a_.opts().WithName("C"));
d = Input(a_.opts().WithName("D"));
Combine(a, a, a_.opts()
.WithName("E")
.WithControlInput(c)
.WithControlInput(d)
.WithControlInput(b));
Combine(a, a,
a_.opts()
.WithName("E")
.WithControlInput(c)
.WithControlInput(d)
.WithControlInput(b));
EXPECT_TRUE(Match()) << diff_;
}
@@ -17,8 +17,8 @@ limitations under the License.
#define TENSORFLOW_CORE_UTIL_MKL_UTIL_H_
#ifdef INTEL_MKL

#include <vector>
#include <string>
#include <vector>

#include "third_party/mkl/include/mkl_dnn.h"
#include "third_party/mkl/include/mkl_dnn_types.h"
@@ -63,9 +63,7 @@ class MklShape {

void SetMklTensor(const bool isMklTensor) { isMklTensor_ = isMklTensor; }

void SetDimensions(const size_t dimension) {
dimension_ = dimension;
}
void SetDimensions(const size_t dimension) { dimension_ = dimension; }

void SetMklLayout(const void* primitive, size_t resourceType) {
CHECK_EQ(
@@ -408,8 +406,8 @@ static inline bool IsMklLayer(const std::string& op_name, DataType T) {
// the type is float. Actually, we should query kernel registration and
// find out if op is supported for type T. But there is no API to query
// kernel registration using name and type.
bool result = (kernel.find(kMklLayerLabelPattern) != string::npos) &&
(T == DT_FLOAT);
bool result =
(kernel.find(kMklLayerLabelPattern) != string::npos) && (T == DT_FLOAT);
if (result == true) {
VLOG(1) << "mkl_layer_registry::" << op_name << " is " << kMklLayerLabel;
}
@@ -28,8 +28,8 @@ import (
"os"
"path/filepath"

tf "github.com/tensorflow/tensorflow/tensorflow/go"
"github.com/tensorflow/tensorflow/tensorflow/go/op"
tf "github.com/tensorflow/tensorflow/tensorflow/go"
)

func Example() {

@@ -158,12 +158,12 @@ func makeOutputList(op *tf.Operation, start int, output string) ([]tf.Output, in
`))

tmplOp = template.Must(template.New("op").Funcs(template.FuncMap{
"MakeComment": makeComment,
"GoType": goType,
"CamelCase": camelCase,
"Identifier": identifier,
"IsListArg": isListArg,
"IsListAttr": isListAttr,
"MakeComment":       makeComment,
"GoType":            goType,
"CamelCase":         camelCase,
"Identifier":        identifier,
"IsListArg":         isListArg,
"IsListAttr":        isListAttr,
"StripLeadingColon": stripLeadingColon,
}).Parse(`
{{if .OptionalAttrs -}}
@@ -266,7 +266,11 @@ class Estimator(object):
checkpoint_path=checkpoint_path,
name=name)

def predict(self, input_fn, predict_keys=None, hooks=None, checkpoint_path=None):
def predict(self,
input_fn,
predict_keys=None,
hooks=None,
checkpoint_path=None):
"""Returns predictions for given features.

Args:

@@ -627,7 +627,10 @@ class EstimatorPredictTest(test.TestCase):
def test_no_trained_model_invalid_checkpoint_path(self):
est = estimator.Estimator(model_fn=model_fn_global_step_incrementer)
with self.assertRaises(ValueError):
next(est.predict(dummy_input_fn, checkpoint_path=saver.latest_checkpoint("fakedir")))
next(
est.predict(
dummy_input_fn,
checkpoint_path=saver.latest_checkpoint('fakedir')))

def test_tensor_predictions(self):

@@ -848,9 +851,12 @@ class EstimatorPredictTest(test.TestCase):
est1 = estimator.Estimator(model_fn=_model_fn)
est1.train(dummy_input_fn, steps=1)
est2 = estimator.Estimator(model_fn=_model_fn, model_dir=est1.model_dir)
self.assertEqual([32.], next(est2.predict(
dummy_input_fn,
checkpoint_path=saver.latest_checkpoint(est1.model_dir))))
self.assertEqual(
[32.],
next(
est2.predict(
dummy_input_fn,
checkpoint_path=saver.latest_checkpoint(est1.model_dir))))

def test_scaffold_is_used(self):
@@ -20,9 +20,9 @@ from __future__ import print_function

import collections
import random
import types as tp
import numpy as np
import six
import types as tp

from tensorflow.python.estimator.inputs.queues import feeding_queue_runner as fqr
from tensorflow.python.framework import dtypes
@@ -245,8 +245,8 @@ class _GeneratorFeedFn(object):

def __call__(self):
if self._num_epochs and self._epoch >= self._num_epochs:
raise errors.OutOfRangeError(
None, None, "Already emitted %s epochs." % self._epoch)
raise errors.OutOfRangeError(None, None,
"Already emitted %s epochs." % self._epoch)
list_dict = {}
list_dict_size = 0
while list_dict_size < self._batch_size:
@@ -258,8 +258,9 @@ class _GeneratorFeedFn(object):
data_row = next(self._iterator)
for index, key in enumerate(self._keys):
if key not in data_row.keys():
raise KeyError('key mismatch between dicts emitted by GenFun'
'Expected {} keys; got {}'.format( self._keys, data_row.keys()))
raise KeyError("key mismatch between dicts emitted by GenFun"
"Expected {} keys; got {}".format(
self._keys, data_row.keys()))
list_dict.setdefault(self._col_placeholders[index],
list()).append(data_row[key])
list_dict_size += 1
@@ -18,8 +18,8 @@ from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import sys
import numpy as np

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
@@ -48,13 +48,13 @@ class TensorUtilTest(test.TestCase):

def testFloatN(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0])
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
@@ -66,12 +66,12 @@ class TensorUtilTest(test.TestCase):

def testFloatTyped(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], dtype=dtypes.float32)
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
""", t)
else:
self.assertProtoEquals("""
dtype: DT_FLOAT
@@ -84,13 +84,13 @@ class TensorUtilTest(test.TestCase):

def testFloatTypeCoerce(self):
t = tensor_util.make_tensor_proto([10, 20, 30], dtype=dtypes.float32)
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
@@ -103,13 +103,13 @@ class TensorUtilTest(test.TestCase):
def testFloatTypeCoerceNdarray(self):
arr = np.asarray([10, 20, 30], dtype="int")
t = tensor_util.make_tensor_proto(arr, dtype=dtypes.float32)
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } }
@@ -121,13 +121,13 @@ class TensorUtilTest(test.TestCase):

def testFloatSizes(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[1, 3])
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 1 } dim { size: 3 } }
@@ -139,13 +139,13 @@ class TensorUtilTest(test.TestCase):

def testFloatSizes2(self):
t = tensor_util.make_tensor_proto([10.0, 20.0, 30.0], shape=[3, 1])
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } dim { size: 1 } }
tensor_content: "A \000\000A\240\000\000A\360\000\000"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_FLOAT
tensor_shape { dim { size: 3 } dim { size: 1 } }
@@ -167,13 +167,13 @@ class TensorUtilTest(test.TestCase):
def testFloatNpArrayFloat64(self):
t = tensor_util.make_tensor_proto(
np.array([[10.0, 20.0, 30.0]], dtype=np.float64))
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_DOUBLE
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "@$\000\000\000\000\000\000@4\000\000\000\000\000\000@>\000\000\000\000\000\000"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_DOUBLE
tensor_shape { dim { size: 1 } dim { size: 3 } }
@@ -258,13 +258,13 @@ class TensorUtilTest(test.TestCase):

def testIntNDefaultType(self):
t = tensor_util.make_tensor_proto([10, 20, 30, 40], shape=[2, 2])
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_INT32
tensor_shape { dim { size: 2 } dim { size: 2 } }
tensor_content: "\000\000\000\\n\000\000\000\024\000\000\000\036\000\000\000("
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_INT32
tensor_shape { dim { size: 2 } dim { size: 2 } }
@@ -328,13 +328,13 @@ class TensorUtilTest(test.TestCase):
def testLongN(self):
t = tensor_util.make_tensor_proto(
[10, 20, 30], shape=[1, 3], dtype=dtypes.int64)
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_INT64
tensor_shape { dim { size: 1 } dim { size: 3 } }
tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_INT64
tensor_shape { dim { size: 1 } dim { size: 3 } }
@@ -346,13 +346,13 @@ class TensorUtilTest(test.TestCase):

def testLongNpArray(self):
t = tensor_util.make_tensor_proto(np.array([10, 20, 30]))
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_INT64
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000\000\000\000\000\000\\n\000\000\000\000\000\000\000\024\000\000\000\000\000\000\000\036"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_INT64
tensor_shape { dim { size: 3 } }
@@ -367,13 +367,13 @@ class TensorUtilTest(test.TestCase):
data = [(21,), (22,), (23,)]

t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint32)
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_QINT32
tensor_shape { dim { size: 3 } }
tensor_content: "\000\000\000\025\000\000\000\026\000\000\000\027"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_QINT32
tensor_shape { dim { size: 3 } }
@@ -404,13 +404,13 @@ class TensorUtilTest(test.TestCase):
self.assertAllEqual(np.array(data, dtype=a.dtype), a)

t = tensor_util.make_tensor_proto(data, dtype=dtypes.quint16)
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_QUINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\000\025\000\026\000\027"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_QUINT16
tensor_shape { dim { size: 3 } }
@@ -421,13 +421,13 @@ class TensorUtilTest(test.TestCase):
self.assertAllEqual(np.array(data, dtype=a.dtype), a)

t = tensor_util.make_tensor_proto(data, dtype=dtypes.qint16)
if sys.byteorder == "big":
self.assertProtoEquals("""
if sys.byteorder == "big":
self.assertProtoEquals("""
dtype: DT_QINT16
tensor_shape { dim { size: 3 } }
tensor_content: "\000\025\000\026\000\027"
""", t)
else:
""", t)
else:
self.assertProtoEquals("""
dtype: DT_QINT16
tensor_shape { dim { size: 3 } }
@@ -669,7 +669,9 @@ class TensorUtilTest(test.TestCase):
self.assertFalse(tensor_util.ShapeEquals(t, [4]))

def testMockArray(self):

class MockArray(object):

def __init__(self, array):
self.array = array

@@ -261,7 +261,7 @@ class PoolingTest(test.TestCase):
padding=padding,
data_format=data_format,
name=func_name)
t_g = gradients_impl.gradients(t ** 2, input_tensor)[0]
t_g = gradients_impl.gradients(t**2, input_tensor)[0]

err_g = gradient_checker.compute_gradient_error(
input_tensor,