Merge commit for internal changes

commit 9dd8e7aec9

.gitignore (vendored): 3 changed lines
@@ -7,11 +7,8 @@ node_modules
 /bazel_pip
 /third_party/eigen3/mkl_include
 /third_party/mkl/*
-/third_party/py/numpy/numpy_include
 /tools/python_bin_path.sh
 /tools/git/gen
-/util/python/python_include
-/util/python/python_lib
 /pip_test
 /_python_build
 *.pyc
@@ -263,6 +263,7 @@ filegroup(
         "//tensorflow/contrib/seq2seq:all_files",
         "//tensorflow/contrib/session_bundle:all_files",
         "//tensorflow/contrib/session_bundle/example:all_files",
+        "//tensorflow/contrib/signal:all_files",
         "//tensorflow/contrib/slim:all_files",
         "//tensorflow/contrib/slim/python/slim/data:all_files",
         "//tensorflow/contrib/slim/python/slim/nets:all_files",
@@ -326,6 +327,48 @@ filegroup(
         "//tensorflow/tensorboard/backend:all_files",
         "//tensorflow/tensorboard/backend/event_processing:all_files",
         "//tensorflow/tensorboard/components:all_files",
+        "//tensorflow/tensorboard/components/tf_audio_dashboard:all_files",
+        "//tensorflow/tensorboard/components/tf_audio_dashboard/demo:all_files",
+        "//tensorflow/tensorboard/components/tf_backend:all_files",
+        "//tensorflow/tensorboard/components/tf_backend_d3v4:all_files",
+        "//tensorflow/tensorboard/components/tf_color_scale:all_files",
+        "//tensorflow/tensorboard/components/tf_color_scale/demo:all_files",
+        "//tensorflow/tensorboard/components/tf_color_scale_d3v4:all_files",
+        "//tensorflow/tensorboard/components/tf_dashboard_common:all_files",
+        "//tensorflow/tensorboard/components/tf_dashboard_common/demo:all_files",
+        "//tensorflow/tensorboard/components/tf_dashboard_common_d3v4:all_files",
+        "//tensorflow/tensorboard/components/tf_distribution_dashboard:all_files",
+        "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo:all_files",
+        "//tensorflow/tensorboard/components/tf_globals:all_files",
+        "//tensorflow/tensorboard/components/tf_globals_d3v4:all_files",
+        "//tensorflow/tensorboard/components/tf_graph_common:all_files",
+        "//tensorflow/tensorboard/components/tf_histogram_dashboard:all_files",
+        "//tensorflow/tensorboard/components/tf_histogram_dashboard/demo:all_files",
+        "//tensorflow/tensorboard/components/tf_image_dashboard:all_files",
+        "//tensorflow/tensorboard/components/tf_image_dashboard/demo:all_files",
+        "//tensorflow/tensorboard/components/tf_imports:all_files",
+        "//tensorflow/tensorboard/components/tf_imports_d3v4:all_files",
+        "//tensorflow/tensorboard/components/tf_scalar_dashboard:all_files",
+        "//tensorflow/tensorboard/components/tf_scalar_dashboard/demo:all_files",
+        "//tensorflow/tensorboard/components/tf_storage:all_files",
+        "//tensorflow/tensorboard/components/tf_storage_d3v4:all_files",
+        "//tensorflow/tensorboard/components/tf_text_dashboard:all_files",
+        "//tensorflow/tensorboard/components/tf_text_dashboard/demo:all_files",
+        "//tensorflow/tensorboard/components/vz_data_summary:all_files",
+        "//tensorflow/tensorboard/components/vz_distribution_chart:all_files",
+        "//tensorflow/tensorboard/components/vz_distribution_chart/demo:all_files",
+        "//tensorflow/tensorboard/components/vz_distribution_chart_d3v4:all_files",
+        "//tensorflow/tensorboard/components/vz_histogram_timeseries:all_files",
+        "//tensorflow/tensorboard/components/vz_histogram_timeseries/demo:all_files",
+        "//tensorflow/tensorboard/components/vz_histogram_timeseries_d3v4:all_files",
+        "//tensorflow/tensorboard/components/vz_line_chart:all_files",
+        "//tensorflow/tensorboard/components/vz_line_chart/demo:all_files",
+        "//tensorflow/tensorboard/components/vz_line_chart_d3v4:all_files",
+        "//tensorflow/tensorboard/components/vz_projector:all_files",
+        "//tensorflow/tensorboard/components/vz_projector_d3v4:all_files",
+        "//tensorflow/tensorboard/components/vz_sorting:all_files",
+        "//tensorflow/tensorboard/components/vz_sorting/test:all_files",
+        "//tensorflow/tensorboard/components/vz_sorting_d3v4:all_files",
         "//tensorflow/tensorboard/lib:all_files",
         "//tensorflow/tensorboard/plugins:all_files",
         "//tensorflow/tensorboard/plugins/projector:all_files",
@@ -28,11 +28,13 @@ limitations under the License.
 #include "tensorflow/compiler/xla/types.h"
 #include "tensorflow/compiler/xla/xla_data.pb.h"
 
-namespace op = xla::testing::opcode_matchers;
-
 namespace xla {
 namespace {
 
+namespace op = xla::testing::opcode_matchers;
+
+using ::testing::_;
+
 class HloRematerializationTest : public HloTestBase {
  protected:
   // Creates and returns a computation which can benefit from
@@ -145,11 +147,9 @@ TEST_F(HloRematerializationTest, SingleComputation) {
   // Find and save the original broadcast instruction which should be
   // rematerialized.
   const HloInstruction* slice = computation->root_instruction();
-  ASSERT_EQ(HloOpcode::kSlice, slice->opcode());
+  ASSERT_THAT(slice, op::Slice(op::Concatenate(op::Broadcast(_), _)));
   const HloInstruction* concat = slice->operand(0);
-  ASSERT_EQ(HloOpcode::kConcatenate, concat->opcode());
   const HloInstruction* bcast = concat->operand(0);
-  ASSERT_EQ(HloOpcode::kBroadcast, bcast->opcode());
 
   SequentialHloOrdering::HloModuleSequence sequence;
   // Computation requires 16KB without rematerialization, but uses only 12KB
@@ -165,8 +165,7 @@ TEST_F(HloRematerializationTest, SingleComputation) {
 
   // The broadcast should have been rematerialized.
   const HloInstruction* remat_bcast = concat->operand(0);
-  EXPECT_EQ(HloOpcode::kBroadcast, remat_bcast->opcode());
-  EXPECT_NE(bcast, remat_bcast);
+  EXPECT_THAT(remat_bcast, op::Broadcast(::testing::Ne(bcast)));
 
   // The rematerialized broadcast should be immediate before the concat in the
   // sequence.
@@ -68,9 +68,8 @@ void CleanNodeName(string* name) {
 }
 
 Status HloTfGraphBuilder::AddComputation(const HloComputation& computation) {
-  LOG(INFO) << "Adding computation " << computation.name();
+  VLOG(2) << "Adding computation " << computation.name();
   for (auto embedded : computation.MakeEmbeddedComputationsList()) {
-    LOG(INFO) << "Adding embedded computation " << embedded->name();
     for (auto& instruction : embedded->instructions()) {
       TF_RETURN_IF_ERROR(AddInstruction(instruction.get()));
     }
@@ -85,7 +85,7 @@ def _init_clusters_random(data, num_clusters, random_seed):
       maxval=math_ops.cast(num_data, dtypes.int64),
       seed=random_seed,
       dtype=dtypes.int64)
-  indices = indices % math_ops.cast(num_data, dtypes.int64)
+  indices %= math_ops.cast(num_data, dtypes.int64)
  clusters_init = embedding_lookup(data, indices, partition_strategy='div')
  return clusters_init
 
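Note: the hunk above replaces `indices = indices % ...` with the augmented assignment `indices %= ...`; behaviour is unchanged, the randomly drawn indices are still wrapped into the range [0, num_data). A minimal numpy-only sketch of that wrapping step (illustrative, not the TensorFlow implementation):

import numpy as np

def wrap_random_indices(num_data, num_clusters, seed=0):
  # Draw candidate indices, then wrap them into [0, num_data) with modulo,
  # mirroring the `indices %= num_data` step in the diff above.
  rng = np.random.RandomState(seed)
  indices = rng.randint(0, 10**9, size=num_clusters)
  indices %= num_data
  return indices

print(wrap_random_indices(num_data=100, num_clusters=5))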
@@ -35,8 +35,8 @@ class GridRNNCellTest(test.TestCase):
 
   def testGrid2BasicLSTMCell(self):
     with self.test_session(use_gpu=False) as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.2)) as root_scope:
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.2)) as root_scope:
         x = array_ops.zeros([1, 3])
         m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
              (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@@ -51,21 +51,22 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[1].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
-                         (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+        res_g, res_s = sess.run([g, s], {
+            x:
+                np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+                (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertEqual(res_s[0].c.shape, (1, 2))
         self.assertEqual(res_s[0].h.shape, (1, 2))
         self.assertEqual(res_s[1].c.shape, (1, 2))
         self.assertEqual(res_s[1].h.shape, (1, 2))
 
-        self.assertAllClose(res_g, ([[0.36617181, 0.36617181]], ))
-        self.assertAllClose(res_s, (([[0.71053141, 0.71053141]],
-                                     [[0.36617181, 0.36617181]]),
-                                    ([[0.72320831, 0.80555487]],
-                                     [[0.39102408, 0.42150158]])))
+        self.assertAllClose(res_g, ([[0.36617181, 0.36617181]],))
+        self.assertAllClose(
+            res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
+                    ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
 
         # emulate a loop through the input sequence,
         # where we call cell() multiple times
@@ -78,22 +79,22 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s2[1].h.get_shape(), (1, 2))
 
         res_g2, res_s2 = sess.run([g2, s2],
-                                  {x: np.array([[2., 2., 2.]]), m: res_s})
+                                  {x: np.array([[2., 2., 2.]]),
+                                   m: res_s})
         self.assertEqual(res_g2[0].shape, (1, 2))
         self.assertEqual(res_s2[0].c.shape, (1, 2))
         self.assertEqual(res_s2[0].h.shape, (1, 2))
         self.assertEqual(res_s2[1].c.shape, (1, 2))
         self.assertEqual(res_s2[1].h.shape, (1, 2))
         self.assertAllClose(res_g2[0], [[0.58847463, 0.58847463]])
-        self.assertAllClose(res_s2, (([[1.40469193, 1.40469193]],
-                                      [[0.58847463, 0.58847463]]),
-                                     ([[0.97726452, 1.04626071]],
-                                      [[0.4927212, 0.51137757]])))
+        self.assertAllClose(
+            res_s2, (([[1.40469193, 1.40469193]], [[0.58847463, 0.58847463]]),
+                     ([[0.97726452, 1.04626071]], [[0.4927212, 0.51137757]])))
 
   def testGrid2BasicLSTMCellTied(self):
     with self.test_session(use_gpu=False) as sess:
       with variable_scope.variable_scope(
           'root', initializer=init_ops.constant_initializer(0.2)):
         x = array_ops.zeros([1, 3])
         m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
              (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
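Note: in the tests above, the feed dict maps the tuple of state placeholders `m` to an equally nested tuple of numpy arrays (for example `m: res_s`), and the session matches the structures element by element. A small pure-Python sketch of how such a nested key/value pair can be flattened into per-placeholder feeds (illustrative only; the names are hypothetical and this is not the TensorFlow session internals):

def flatten_feed(key, value, out=None):
  # Recursively pair nested placeholder structures with nested values.
  if out is None:
    out = {}
  if isinstance(key, tuple):
    assert isinstance(value, tuple) and len(key) == len(value)
    for k, v in zip(key, value):
      flatten_feed(k, v, out)
  else:
    out[key] = value
  return out

feed = flatten_feed(('m00', 'm01'), ([[0.1, 0.2]], [[0.3, 0.4]]))
print(feed)  # {'m00': [[0.1, 0.2]], 'm01': [[0.3, 0.4]]}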
@@ -108,10 +109,12 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[1].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
-                         (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+        res_g, res_s = sess.run([g, s], {
+            x:
+                np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+                (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertEqual(res_s[0].c.shape, (1, 2))
         self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -119,29 +122,27 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(res_s[1].h.shape, (1, 2))
 
         self.assertAllClose(res_g[0], [[0.36617181, 0.36617181]])
-        self.assertAllClose(res_s, (([[0.71053141, 0.71053141]],
-                                     [[0.36617181, 0.36617181]]),
-                                    ([[0.72320831, 0.80555487]],
-                                     [[0.39102408, 0.42150158]])))
+        self.assertAllClose(
+            res_s, (([[0.71053141, 0.71053141]], [[0.36617181, 0.36617181]]),
+                    ([[0.72320831, 0.80555487]], [[0.39102408, 0.42150158]])))
 
         res_g, res_s = sess.run([g, s], {x: np.array([[1., 1., 1.]]), m: res_s})
         self.assertEqual(res_g[0].shape, (1, 2))
 
         self.assertAllClose(res_g[0], [[0.36703536, 0.36703536]])
-        self.assertAllClose(res_s, (([[0.71200621, 0.71200621]],
-                                     [[0.36703536, 0.36703536]]),
-                                    ([[0.80941606, 0.87550586]],
-                                     [[0.40108523, 0.42199609]])))
+        self.assertAllClose(
+            res_s, (([[0.71200621, 0.71200621]], [[0.36703536, 0.36703536]]),
+                    ([[0.80941606, 0.87550586]], [[0.40108523, 0.42199609]])))
 
   def testGrid2BasicLSTMCellWithRelu(self):
     with self.test_session(use_gpu=False) as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.2)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.2)):
         x = array_ops.zeros([1, 3])
         m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
         cell = grid_rnn_cell.Grid2BasicLSTMCell(
             2, tied=False, non_recurrent_fn=nn_ops.relu)
-        self.assertEqual(cell.state_size, ((2, 2), ))
+        self.assertEqual(cell.state_size, ((2, 2),))
 
         g, s = cell(x, m)
         self.assertEqual(g[0].get_shape(), (1, 2))
@@ -149,21 +150,22 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[0].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
+        res_g, res_s = sess.run([g, s], {
+            x: np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertAllClose(res_g[0], [[0.31667367, 0.31667367]])
         self.assertAllClose(res_s, (([[0.29530135, 0.37520045]],
-                                     [[0.17044567, 0.21292259]]), ))
+                                     [[0.17044567, 0.21292259]]),))
 
   """LSTMCell
   """
 
   def testGrid2LSTMCell(self):
     with self.test_session(use_gpu=False) as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
              (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@@ -178,10 +180,12 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[1].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
-                         (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+        res_g, res_s = sess.run([g, s], {
+            x:
+                np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+                (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertEqual(res_s[0].c.shape, (1, 2))
         self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -189,15 +193,14 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(res_s[1].h.shape, (1, 2))
 
         self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
-        self.assertAllClose(res_s, (([[2.41515064, 2.41515064]],
-                                     [[0.95686918, 0.95686918]]),
-                                    ([[1.38917875, 1.49043763]],
-                                     [[0.83884692, 0.86036491]])))
+        self.assertAllClose(
+            res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
+                    ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
 
   def testGrid2LSTMCellTied(self):
     with self.test_session(use_gpu=False) as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
              (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])))
@@ -212,10 +215,12 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[1].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
-                         (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))})
+        res_g, res_s = sess.run([g, s], {
+            x:
+                np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+                (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])))
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertEqual(res_s[0].c.shape, (1, 2))
         self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -223,15 +228,14 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(res_s[1].h.shape, (1, 2))
 
         self.assertAllClose(res_g[0], [[0.95686918, 0.95686918]])
-        self.assertAllClose(res_s, (([[2.41515064, 2.41515064]],
-                                     [[0.95686918, 0.95686918]]),
-                                    ([[1.38917875, 1.49043763]],
-                                     [[0.83884692, 0.86036491]])))
+        self.assertAllClose(
+            res_s, (([[2.41515064, 2.41515064]], [[0.95686918, 0.95686918]]),
+                    ([[1.38917875, 1.49043763]], [[0.83884692, 0.86036491]])))
 
   def testGrid2LSTMCellWithRelu(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
         cell = grid_rnn_cell.Grid2LSTMCell(
@@ -244,21 +248,22 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[0].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
+        res_g, res_s = sess.run([g, s], {
+            x: np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertAllClose(res_g[0], [[2.1831727, 2.1831727]])
         self.assertAllClose(res_s, (([[0.92270052, 1.02325559]],
-                                     [[0.66159075, 0.70475441]]), ))
+                                     [[0.66159075, 0.70475441]]),))
 
   """RNNCell
   """
 
   def testGrid2BasicRNNCell(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([2, 2])
         m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
         cell = grid_rnn_cell.Grid2BasicRNNCell(2)
@@ -270,26 +275,26 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[1].get_shape(), (2, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1.], [2., 2.]]),
-                     m: (np.array([[0.1, 0.1], [0.2, 0.2]]),
-                         np.array([[0.1, 0.1], [0.2, 0.2]]))})
+        res_g, res_s = sess.run([g, s], {
+            x:
+                np.array([[1., 1.], [2., 2.]]),
+            m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
+                                                              [0.2, 0.2]]))
+        })
         self.assertEqual(res_g[0].shape, (2, 2))
         self.assertEqual(res_s[0].shape, (2, 2))
         self.assertEqual(res_s[1].shape, (2, 2))
 
         self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
-                                     [0.99480951, 0.99480951]], ))
-        self.assertAllClose(res_s,
-                            ([[0.94685763, 0.94685763],
-                              [0.99480951, 0.99480951]],
-                             [[0.80049908, 0.80049908],
-                              [0.97574311, 0.97574311]]))
+                                     [0.99480951, 0.99480951]],))
+        self.assertAllClose(
+            res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
+                    [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
 
   def testGrid2BasicRNNCellTied(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([2, 2])
         m = (array_ops.zeros([2, 2]), array_ops.zeros([2, 2]))
         cell = grid_rnn_cell.Grid2BasicRNNCell(2, tied=True)
@@ -301,55 +306,55 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[1].get_shape(), (2, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1.], [2., 2.]]),
-                     m: (np.array([[0.1, 0.1], [0.2, 0.2]]),
-                         np.array([[0.1, 0.1], [0.2, 0.2]]))})
+        res_g, res_s = sess.run([g, s], {
+            x:
+                np.array([[1., 1.], [2., 2.]]),
+            m: (np.array([[0.1, 0.1], [0.2, 0.2]]), np.array([[0.1, 0.1],
+                                                              [0.2, 0.2]]))
+        })
         self.assertEqual(res_g[0].shape, (2, 2))
         self.assertEqual(res_s[0].shape, (2, 2))
         self.assertEqual(res_s[1].shape, (2, 2))
 
         self.assertAllClose(res_g, ([[0.94685763, 0.94685763],
-                                     [0.99480951, 0.99480951]], ))
-        self.assertAllClose(res_s,
-                            ([[0.94685763, 0.94685763],
-                              [0.99480951, 0.99480951]],
-                             [[0.80049908, 0.80049908],
-                              [0.97574311, 0.97574311]]))
+                                     [0.99480951, 0.99480951]],))
+        self.assertAllClose(
+            res_s, ([[0.94685763, 0.94685763], [0.99480951, 0.99480951]],
+                    [[0.80049908, 0.80049908], [0.97574311, 0.97574311]]))
 
   def testGrid2BasicRNNCellWithRelu(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
-        m = (array_ops.zeros([1, 2]), )
-        cell = grid_rnn_cell.Grid2BasicRNNCell(
-            2, non_recurrent_fn=nn_ops.relu)
-        self.assertEqual(cell.state_size, (2, ))
+        m = (array_ops.zeros([1, 2]),)
+        cell = grid_rnn_cell.Grid2BasicRNNCell(2, non_recurrent_fn=nn_ops.relu)
+        self.assertEqual(cell.state_size, (2,))
 
         g, s = cell(x, m)
         self.assertEqual(g[0].get_shape(), (1, 2))
         self.assertEqual(s[0].get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run([g, s], {x: np.array([[1., 1.]]),
-                                         m: np.array([[0.1, 0.1]])})
+        res_g, res_s = sess.run(
+            [g, s], {x: np.array([[1., 1.]]),
+                     m: np.array([[0.1, 0.1]])})
        self.assertEqual(res_g[0].shape, (1, 2))
        self.assertEqual(res_s[0].shape, (1, 2))
-        self.assertAllClose(res_g, ([[1.80049896, 1.80049896]], ))
-        self.assertAllClose(res_s, ([[0.80049896, 0.80049896]], ))
+        self.assertAllClose(res_g, ([[1.80049896, 1.80049896]],))
+        self.assertAllClose(res_s, ([[0.80049896, 0.80049896]],))
 
   """1-LSTM
   """
 
   def testGrid1LSTMCell(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)) as root_scope:
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)) as root_scope:
         x = array_ops.zeros([1, 3])
-        m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), )
+        m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
         cell = grid_rnn_cell.Grid1LSTMCell(2, use_peepholes=True)
-        self.assertEqual(cell.state_size, ((2, 2), ))
+        self.assertEqual(cell.state_size, ((2, 2),))
 
         g, s = cell(x, m)
         self.assertEqual(g[0].get_shape(), (1, 2))
@@ -357,17 +362,17 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[0].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])), )})
+        res_g, res_s = sess.run([g, s], {
+            x: np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),)
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertEqual(res_s[0].c.shape, (1, 2))
         self.assertEqual(res_s[0].h.shape, (1, 2))
 
-        self.assertAllClose(res_g, ([[0.91287315, 0.91287315]], ))
-        self.assertAllClose(res_s,
-                            (([[2.26285243, 2.26285243]],
-                              [[0.91287315, 0.91287315]]), ))
+        self.assertAllClose(res_g, ([[0.91287315, 0.91287315]],))
+        self.assertAllClose(res_s, (([[2.26285243, 2.26285243]],
+                                     [[0.91287315, 0.91287315]]),))
 
         root_scope.reuse_variables()
 
@@ -383,10 +388,9 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(res_s2[0].c.shape, (1, 2))
         self.assertEqual(res_s2[0].h.shape, (1, 2))
 
-        self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]], ))
-        self.assertAllClose(res_s2,
-                            (([[2.79966092, 2.79966092]],
-                              [[0.9032144, 0.9032144]]), ))
+        self.assertAllClose(res_g2, ([[0.9032144, 0.9032144]],))
+        self.assertAllClose(res_s2, (([[2.79966092, 2.79966092]],
+                                      [[0.9032144, 0.9032144]]),))
 
         g3, s3 = cell(x2, m)
         self.assertEqual(g3[0].get_shape(), (1, 2))
@@ -398,18 +402,17 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(res_g3[0].shape, (1, 2))
         self.assertEqual(res_s3[0].c.shape, (1, 2))
         self.assertEqual(res_s3[0].h.shape, (1, 2))
-        self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]], ))
-        self.assertAllClose(res_s3,
-                            (([[3.3529923, 3.3529923]],
-                              [[0.92727238, 0.92727238]]), ))
+        self.assertAllClose(res_g3, ([[0.92727238, 0.92727238]],))
+        self.assertAllClose(res_s3, (([[3.3529923, 3.3529923]],
+                                      [[0.92727238, 0.92727238]]),))
 
   """3-LSTM
   """
 
   def testGrid3LSTMCell(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
              (array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),
@@ -427,11 +430,13 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[2].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
-                         (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])),
-                         (np.array([[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))})
+        res_g, res_s = sess.run([g, s], {
+            x:
+                np.array([[1., 1., 1.]]),
+            m: ((np.array([[0.1, 0.2]]), np.array([[0.3, 0.4]])),
+                (np.array([[0.5, 0.6]]), np.array([[0.7, 0.8]])), (np.array(
+                    [[-0.1, -0.2]]), np.array([[-0.3, -0.4]])))
+        })
         self.assertEqual(res_g[0].shape, (1, 2))
         self.assertEqual(res_s[0].c.shape, (1, 2))
         self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -440,21 +445,19 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(res_s[2].c.shape, (1, 2))
         self.assertEqual(res_s[2].h.shape, (1, 2))
 
-        self.assertAllClose(res_g, ([[0.96892911, 0.96892911]], ))
-        self.assertAllClose(res_s, (([[2.45227885, 2.45227885]],
-                                     [[0.96892911, 0.96892911]]),
-                                    ([[1.33592629, 1.4373529]],
-                                     [[0.80867189, 0.83247656]]),
-                                    ([[0.7317788, 0.63205892]],
-                                     [[0.56548983, 0.50446129]])))
+        self.assertAllClose(res_g, ([[0.96892911, 0.96892911]],))
+        self.assertAllClose(
+            res_s, (([[2.45227885, 2.45227885]], [[0.96892911, 0.96892911]]),
+                    ([[1.33592629, 1.4373529]], [[0.80867189, 0.83247656]]),
+                    ([[0.7317788, 0.63205892]], [[0.56548983, 0.50446129]])))
 
   """Edge cases
   """
 
   def testGridRNNEdgeCasesLikeRelu(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([3, 2])
         m = ()
 
@@ -471,18 +474,18 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s, ())
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
+        res_g, res_s = sess.run([g, s],
+                                {x: np.array([[1., -1.], [-2, 1], [2, -1]])})
         self.assertEqual(res_g[0].shape, (3, 2))
         self.assertEqual(res_s, ())
-        self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]], ))
+        self.assertAllClose(res_g, ([[0, 0], [0, 0], [0.5, 0.5]],))
 
   def testGridRNNEdgeCasesNoOutput(self):
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 2])
-        m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])), )
+        m = ((array_ops.zeros([1, 2]), array_ops.zeros([1, 2])),)
 
         # This cell produces no output
         cell = grid_rnn_cell.GridRNNCell(
@@ -498,9 +501,10 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s[0].h.get_shape(), (1, 2))
 
         sess.run([variables.global_variables_initializer()])
-        res_g, res_s = sess.run(
-            [g, s], {x: np.array([[1., 1.]]),
-                     m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])), )})
+        res_g, res_s = sess.run([g, s], {
+            x: np.array([[1., 1.]]),
+            m: ((np.array([[0.1, 0.1]]), np.array([[0.1, 0.1]])),)
+        })
         self.assertEqual(res_g, ())
         self.assertEqual(res_s[0].c.shape, (1, 2))
         self.assertEqual(res_s[0].h.shape, (1, 2))
@@ -561,8 +565,9 @@ class GridRNNCellTest(test.TestCase):
       cell = grid_rnn_cell.Grid2LSTMCell(
           num_units=num_units, non_recurrent_fn=nn_ops.relu)
 
-      inputs = max_length * [array_ops.placeholder(
-          dtypes.float32, shape=(batch_size, input_size))]
+      inputs = max_length * [
+          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
+      ]
 
       outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
 
@@ -600,8 +605,9 @@ class GridRNNCellTest(test.TestCase):
      cell = grid_rnn_cell.Grid3LSTMCell(
          num_units=num_units, non_recurrent_fn=nn_ops.relu)
 
-      inputs = max_length * [array_ops.placeholder(
-          dtypes.float32, shape=(batch_size, input_size))]
+      inputs = max_length * [
+          array_ops.placeholder(dtypes.float32, shape=(batch_size, input_size))
+      ]
 
       outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
 
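Note: both hunks above keep the pattern `inputs = max_length * [array_ops.placeholder(...)]`, which builds a list containing `max_length` references to a single placeholder rather than `max_length` distinct placeholders. A pure-Python sketch of that list-replication behaviour (illustrative; no TensorFlow required, the class name is hypothetical):

class Placeholder(object):
  """Hypothetical stand-in object, used only to show list replication."""

max_length = 6
inputs = max_length * [Placeholder()]
print(len(inputs))                          # 6
print(all(p is inputs[0] for p in inputs))  # True: one shared object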
@@ -671,19 +677,17 @@ class GridRNNCellTest(test.TestCase):
           self.assertTrue(np.all(np.isfinite(v)))
 
   def testGrid2LSTMCellWithRNNAndDynamicBatchSize(self):
-    """Test for #4296
-    """
+    """Test for #4296."""
     input_size = 5
     max_length = 6  # unrolled up to this length
     num_units = 2
 
-    with variable_scope.variable_scope('root',
-                                       initializer=init_ops.constant_initializer(0.5)):
+    with variable_scope.variable_scope(
+        'root', initializer=init_ops.constant_initializer(0.5)):
       cell = grid_rnn_cell.Grid2LSTMCell(num_units=num_units)
 
       inputs = max_length * [
-          array_ops.placeholder(
-              dtypes.float32, shape=(None, input_size))
+          array_ops.placeholder(dtypes.float32, shape=(None, input_size))
       ]
 
       outputs, state = core_rnn.static_rnn(cell, inputs, dtype=dtypes.float32)
@@ -700,8 +704,7 @@ class GridRNNCellTest(test.TestCase):
       sess.run(variables.global_variables_initializer())
 
      input_value = np.ones((3, input_size))
-      values = sess.run(outputs + [state],
-                        feed_dict={inputs[0]: input_value})
+      values = sess.run(outputs + [state], feed_dict={inputs[0]: input_value})
 
      for tp in values[:-1]:
        for v in tp:
          self.assertTrue(np.all(np.isfinite(v)))
@@ -710,18 +713,15 @@ class GridRNNCellTest(test.TestCase):
        for v in st:
          self.assertTrue(np.all(np.isfinite(v)))
 
-
   def testGrid2LSTMCellLegacy(self):
-    """Test for legacy case (when state_is_tuple=False)
-    """
+    """Test for legacy case (when state_is_tuple=False)."""
     with self.test_session() as sess:
-      with variable_scope.variable_scope('root',
-                                         initializer=init_ops.constant_initializer(0.5)):
+      with variable_scope.variable_scope(
+          'root', initializer=init_ops.constant_initializer(0.5)):
         x = array_ops.zeros([1, 3])
         m = array_ops.zeros([1, 8])
-        cell = grid_rnn_cell.Grid2LSTMCell(2, use_peepholes=True,
-                                           state_is_tuple=False,
-                                           output_is_tuple=False)
+        cell = grid_rnn_cell.Grid2LSTMCell(
+            2, use_peepholes=True, state_is_tuple=False, output_is_tuple=False)
         self.assertEqual(cell.state_size, 8)
 
         g, s = cell(x, m)
@@ -729,15 +729,17 @@ class GridRNNCellTest(test.TestCase):
         self.assertEqual(s.get_shape(), (1, 8))
 
         sess.run([variables.global_variables_initializer()])
-        res = sess.run(
-            [g, s], {x: np.array([[1., 1., 1.]]),
-                     m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])})
+        res = sess.run([g, s], {
+            x: np.array([[1., 1., 1.]]),
+            m: np.array([[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8]])
+        })
         self.assertEqual(res[0].shape, (1, 2))
         self.assertEqual(res[1].shape, (1, 8))
         self.assertAllClose(res[0], [[0.95686918, 0.95686918]])
-        self.assertAllClose(res[1], [[2.41515064, 2.41515064, 0.95686918,
-                                      0.95686918, 1.38917875, 1.49043763,
-                                      0.83884692, 0.86036491]])
+        self.assertAllClose(res[1], [[
+            2.41515064, 2.41515064, 0.95686918, 0.95686918, 1.38917875,
+            1.49043763, 0.83884692, 0.86036491
+        ]])
 
 if __name__ == '__main__':
   test.main()
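Note: in the legacy test above (`state_is_tuple=False`), `cell.state_size` is 8 and the expected state row lays the per-dimension LSTM `(c, h)` pairs side by side; the eight values match the tuple-form results checked by the earlier Grid2LSTM tests. A small sketch of the arithmetic, assuming 2 recurrent dimensions each carrying a `(c, h)` pair of `num_units` values:

num_units = 2
num_recurrent_dims = 2
values_per_dim = 2  # c and h
concatenated_state_size = num_recurrent_dims * values_per_dim * num_units
print(concatenated_state_size)  # 8, matching cell.state_size in the test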
@@ -102,16 +102,16 @@ class GridRNNCell(rnn.RNNCell):
       output_is_tuple: If True, the output is a tuple of the outputs of the
         recurrent dimensions. If False, they are concatenated along the
         column axis. The later behavior will soon be deprecated.
 
     Raises:
       TypeError: if cell_fn does not return an RNNCell instance.
     """
     if not state_is_tuple:
-      logging.warning("%s: Using a concatenated state is slower and will "
-                      "soon be deprecated. Use state_is_tuple=True.", self)
+      logging.warning('%s: Using a concatenated state is slower and will '
+                      'soon be deprecated. Use state_is_tuple=True.', self)
     if not output_is_tuple:
-      logging.warning("%s: Using a concatenated output is slower and will"
-                      "soon be deprecated. Use output_is_tuple=True.", self)
+      logging.warning('%s: Using a concatenated output is slower and will'
+                      'soon be deprecated. Use output_is_tuple=True.', self)
 
     if num_dims < 1:
       raise ValueError('dims must be >= 1: {}'.format(num_dims))
@@ -126,9 +126,7 @@ class GridRNNCell(rnn.RNNCell):
 
     if cell_fn is None:
       my_cell_fn = functools.partial(
-          rnn.LSTMCell,
-          num_units=num_units,
-          state_is_tuple=state_is_tuple)
+          rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)
     else:
       my_cell_fn = lambda: cell_fn(num_units)
     if tied:
@@ -136,9 +134,8 @@ class GridRNNCell(rnn.RNNCell):
     else:
       self._cells = [my_cell_fn() for _ in range(num_dims)]
     if not isinstance(self._cells[0], rnn.RNNCell):
-      raise TypeError(
-          'cell_fn must return an RNNCell instance, saw: %s'
-          % type(self._cells[0]))
+      raise TypeError('cell_fn must return an RNNCell instance, saw: %s' %
+                      type(self._cells[0]))
 
     if self._output_is_tuple:
       self._output_size = tuple(self._cells[0].output_size
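Note: one of the hunks above only reflows the `functools.partial(rnn.LSTMCell, num_units=num_units, state_is_tuple=state_is_tuple)` call onto a single line; the partial still acts as a zero-argument cell factory. A minimal sketch of the same idea with a hypothetical cell class (not the TensorFlow one):

import functools

class FakeLSTMCell(object):
  # Hypothetical stand-in for rnn.LSTMCell.
  def __init__(self, num_units, state_is_tuple=True):
    self.num_units = num_units
    self.state_is_tuple = state_is_tuple

my_cell_fn = functools.partial(FakeLSTMCell, num_units=4, state_is_tuple=True)
cells = [my_cell_fn() for _ in range(3)]  # one cell per grid dimension
print([c.num_units for c in cells])       # [4, 4, 4]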
@@ -201,26 +198,36 @@ class GridRNNCell(rnn.RNNCell):
     if self._output_is_tuple:
       output = tuple(output_tensors)
     else:
-      if len(output_tensors) == 0:
-        output = array_ops.zeros([0, 0], dtype)
-      else:
+      if output_tensors:
         output = array_ops.concat(output_tensors, 1)
+      else:
+        output = array_ops.zeros([0, 0], dtype)
 
     if self._state_is_tuple:
       states = tuple(new_state[i] for i in self._config.recurrents)
     else:
       # concat each state first, then flatten the whole thing
-      state_tensors = [x for i in self._config.recurrents
-                       for x in new_state[i]]
-      if len(state_tensors) == 0:
-        states = array_ops.zeros([0, 0], dtype)
-      else:
+      state_tensors = [
+          x for i in self._config.recurrents for x in new_state[i]
+      ]
+      if state_tensors:
         states = array_ops.concat(state_tensors, 1)
+      else:
+        states = array_ops.zeros([0, 0], dtype)
 
     return output, states
 
   def _extract_states(self, state):
-    """Extract the cell and previous output tensors from the given state
+    """Extract the cell and previous output tensors from the given state.
+
+    Args:
+      state: The RNN state.
+
+    Returns:
+      Tuple of the cell value, previous output, and cell_output_size.
+
+    Raises:
+      ValueError: If len(self._config.recurrents) != len(state).
     """
     conf = self._config
 
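Note: the refactor above replaces `if len(output_tensors) == 0:` with a truthiness check and puts the common concatenation case first; an empty list still yields the `[0, 0]` placeholder. A numpy-only sketch of the same branch structure (illustrative, not the TensorFlow ops):

import numpy as np

def concat_or_empty(tensors):
  # Mirrors the new branch order: concatenate when the list is non-empty,
  # otherwise return an empty 0x0 placeholder.
  if tensors:
    return np.concatenate(tensors, axis=1)
  else:
    return np.zeros((0, 0))

print(concat_or_empty([np.ones((1, 2)), np.zeros((1, 3))]).shape)  # (1, 5)
print(concat_or_empty([]).shape)                                   # (0, 0)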
@@ -238,8 +245,8 @@ class GridRNNCell(rnn.RNNCell):
 
     if self._state_is_tuple:
       if len(conf.recurrents) != len(state):
-        raise ValueError("Expected state as a tuple of {} "
-                         "element".format(len(conf.recurrents)))
+        raise ValueError('Expected state as a tuple of {} '
+                         'element'.format(len(conf.recurrents)))
 
       for recurrent_dim, recurrent_state in zip(conf.recurrents, state):
         if cell_output_size > 0:
@@ -247,49 +254,62 @@ class GridRNNCell(rnn.RNNCell):
         else:
           m_prev[recurrent_dim] = recurrent_state
     else:
-      for recurrent_dim, start_idx in zip(conf.recurrents, range(
-          0, self.state_size, total_cell_state_size)):
+      for recurrent_dim, start_idx in zip(conf.recurrents,
+                                          range(0, self.state_size,
+                                                total_cell_state_size)):
        if cell_output_size > 0:
          c_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                  [-1, conf.num_units])
          m_prev[recurrent_dim] = array_ops.slice(
              state, [0, start_idx + conf.num_units], [-1, cell_output_size])
        else:
          m_prev[recurrent_dim] = array_ops.slice(state, [0, start_idx],
                                                  [-1, conf.num_units])
    return c_prev, m_prev, cell_output_size
 
   def _project_input(self, inputs, c_prev, m_prev, with_c):
-    """Fills in c_prev and m_prev with projected input, for input dimensions
+    """Fills in c_prev and m_prev with projected input, for input dimensions.
+
+    Args:
+      inputs: inputs tensor
+      c_prev: cell value
+      m_prev: previous output
+      with_c: boolean; whether to include project_c.
+
+    Raises:
+      ValueError: if len(self._config.input) != len(inputs)
     """
     conf = self._config
 
-    if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0
-        and len(conf.inputs) > 0):
+    if (inputs is not None and inputs.get_shape().with_rank(2)[1].value > 0 and
+        conf.inputs):
       if isinstance(inputs, tuple):
         if len(conf.inputs) != len(inputs):
-          raise ValueError("Expect inputs as a tuple of {} "
-                           "tensors".format(len(conf.inputs)))
+          raise ValueError('Expect inputs as a tuple of {} '
+                           'tensors'.format(len(conf.inputs)))
         input_splits = inputs
       else:
         input_splits = array_ops.split(
             value=inputs, num_or_size_splits=len(conf.inputs), axis=1)
       input_sz = input_splits[0].get_shape().with_rank(2)[1].value
 
       for i, j in enumerate(conf.inputs):
         input_project_m = vs.get_variable(
             'project_m_{}'.format(j), [input_sz, conf.num_units],
             dtype=inputs.dtype)
         m_prev[j] = math_ops.matmul(input_splits[i], input_project_m)
 
         if with_c:
           input_project_c = vs.get_variable(
               'project_c_{}'.format(j), [input_sz, conf.num_units],
               dtype=inputs.dtype)
           c_prev[j] = math_ops.matmul(input_splits[i], input_project_c)
 
   def _cell_state_size(self):
-    """Total size of the state of the inner cell used in this grid
+    """Total size of the state of the inner cell used in this grid.
+
+    Returns:
+      Total size of the state of the inner cell.
     """
     state_sizes = self._cells[0].state_size
     if isinstance(state_sizes, tuple):
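Note: in the non-tuple branch above, `_extract_states` walks the flat state in strides of `total_cell_state_size` and slices out, per recurrent dimension, `num_units` cell values followed by the cell output. A numpy sketch of that slicing, assuming two recurrent dimensions, `num_units = 2` and an LSTM-style `(c, m)` layout (hypothetical values):

import numpy as np

num_units = 2
cell_output_size = 2
total_cell_state_size = num_units + cell_output_size  # 4 per dimension
state = np.arange(8, dtype=float).reshape(1, 8)       # batch of 1, 2 dims

c_prev, m_prev = {}, {}
for dim, start_idx in zip([0, 1], range(0, 8, total_cell_state_size)):
  c_prev[dim] = state[:, start_idx:start_idx + num_units]
  m_prev[dim] = state[:, start_idx + num_units:start_idx + total_cell_state_size]

print(c_prev[0], m_prev[0])  # [[0. 1.]] [[2. 3.]]
print(c_prev[1], m_prev[1])  # [[4. 5.]] [[6. 7.]]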
@@ -306,10 +326,15 @@ class Grid1BasicRNNCell(GridRNNCell):
 
   def __init__(self, num_units, state_is_tuple=True, output_is_tuple=True):
     super(Grid1BasicRNNCell, self).__init__(
-        num_units=num_units, num_dims=1,
-        input_dims=0, output_dims=0, priority_dims=0, tied=False,
-        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
-        state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+        num_units=num_units,
+        num_dims=1,
+        input_dims=0,
+        output_dims=0,
+        priority_dims=0,
+        tied=False,
+        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+        state_is_tuple=state_is_tuple,
+        output_is_tuple=output_is_tuple)
 
 
 class Grid2BasicRNNCell(GridRNNCell):
@@ -322,38 +347,56 @@ class Grid2BasicRNNCell(GridRNNCell):
   specified.
   """
 
-  def __init__(self, num_units, tied=False, non_recurrent_fn=None,
-               state_is_tuple=True, output_is_tuple=True):
+  def __init__(self,
+               num_units,
+               tied=False,
+               non_recurrent_fn=None,
+               state_is_tuple=True,
+               output_is_tuple=True):
     super(Grid2BasicRNNCell, self).__init__(
-        num_units=num_units, num_dims=2,
-        input_dims=0, output_dims=0, priority_dims=0, tied=tied,
-        non_recurrent_dims=None if non_recurrent_fn is None else 0,
-        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
-        non_recurrent_fn=non_recurrent_fn,
-        state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+        num_units=num_units,
+        num_dims=2,
+        input_dims=0,
+        output_dims=0,
+        priority_dims=0,
+        tied=tied,
+        non_recurrent_dims=None if non_recurrent_fn is None else 0,
+        cell_fn=lambda n: rnn.BasicRNNCell(num_units=n),
+        non_recurrent_fn=non_recurrent_fn,
+        state_is_tuple=state_is_tuple,
+        output_is_tuple=output_is_tuple)
 
 
 class Grid1BasicLSTMCell(GridRNNCell):
-  """1D BasicLSTM cell"""
+  """1D BasicLSTM cell."""
 
-  def __init__(self, num_units, forget_bias=1,
-               state_is_tuple=True, output_is_tuple=True):
+  def __init__(self,
+               num_units,
+               forget_bias=1,
+               state_is_tuple=True,
+               output_is_tuple=True):
+    def cell_fn(n):
+      return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
     super(Grid1BasicLSTMCell, self).__init__(
-        num_units=num_units, num_dims=1,
-        input_dims=0, output_dims=0, priority_dims=0, tied=False,
-        cell_fn=lambda n: rnn.BasicLSTMCell(
-            num_units=n, forget_bias=forget_bias),
-        state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
+        num_units=num_units,
+        num_dims=1,
+        input_dims=0,
+        output_dims=0,
+        priority_dims=0,
+        tied=False,
+        cell_fn=cell_fn,
+        state_is_tuple=state_is_tuple,
+        output_is_tuple=output_is_tuple)
 
 
 class Grid2BasicLSTMCell(GridRNNCell):
-  """2D BasicLSTM cell
+  """2D BasicLSTM cell.
 
   This creates a 2D cell which receives input and gives output in the first
   dimension.
 
   The first dimension can optionally be non-recurrent if `non_recurrent_fn` is
   specified.
   """
 
   def __init__(self,
@ -363,36 +406,53 @@ class Grid2BasicLSTMCell(GridRNNCell):
|
|||||||
forget_bias=1,
|
forget_bias=1,
|
||||||
state_is_tuple=True,
|
state_is_tuple=True,
|
||||||
output_is_tuple=True):
|
output_is_tuple=True):
|
||||||
|
def cell_fn(n):
|
||||||
|
return rnn.BasicLSTMCell(num_units=n, forget_bias=forget_bias)
|
||||||
super(Grid2BasicLSTMCell, self).__init__(
|
super(Grid2BasicLSTMCell, self).__init__(
|
||||||
num_units=num_units, num_dims=2,
|
num_units=num_units,
|
||||||
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
|
num_dims=2,
|
||||||
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
input_dims=0,
|
||||||
cell_fn=lambda n: rnn.BasicLSTMCell(
|
output_dims=0,
|
||||||
num_units=n, forget_bias=forget_bias),
|
priority_dims=0,
|
||||||
non_recurrent_fn=non_recurrent_fn,
|
tied=tied,
|
||||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
||||||
|
cell_fn=cell_fn,
|
||||||
|
non_recurrent_fn=non_recurrent_fn,
|
||||||
|
state_is_tuple=state_is_tuple,
|
||||||
|
output_is_tuple=output_is_tuple)
|
||||||
|
|
||||||
|
|
||||||
class Grid1LSTMCell(GridRNNCell):
|
class Grid1LSTMCell(GridRNNCell):
|
||||||
"""1D LSTM cell
|
"""1D LSTM cell.
|
||||||
|
|
||||||
This is different from Grid1BasicLSTMCell because it gives options to
|
This is different from Grid1BasicLSTMCell because it gives options to
|
||||||
specify the forget bias and enabling peepholes
|
specify the forget bias and enabling peepholes.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
def __init__(self, num_units, use_peepholes=False, forget_bias=1.0,
|
def __init__(self,
|
||||||
state_is_tuple=True, output_is_tuple=True):
|
num_units,
|
||||||
|
use_peepholes=False,
|
||||||
|
forget_bias=1.0,
|
||||||
|
state_is_tuple=True,
|
||||||
|
output_is_tuple=True):
|
||||||
|
|
||||||
|
def cell_fn(n):
|
||||||
|
return rnn.LSTMCell(
|
||||||
|
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
|
||||||
|
|
||||||
super(Grid1LSTMCell, self).__init__(
|
super(Grid1LSTMCell, self).__init__(
|
||||||
num_units=num_units, num_dims=1,
|
num_units=num_units,
|
||||||
input_dims=0, output_dims=0, priority_dims=0,
|
num_dims=1,
|
||||||
cell_fn=lambda n: rnn.LSTMCell(
|
input_dims=0,
|
||||||
num_units=n, use_peepholes=use_peepholes,
|
output_dims=0,
|
||||||
forget_bias=forget_bias),
|
priority_dims=0,
|
||||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
cell_fn=cell_fn,
|
||||||
|
state_is_tuple=state_is_tuple,
|
||||||
|
output_is_tuple=output_is_tuple)
|
||||||
|
|
||||||
|
|
||||||
class Grid2LSTMCell(GridRNNCell):
|
class Grid2LSTMCell(GridRNNCell):
|
||||||
"""2D LSTM cell
|
"""2D LSTM cell.
|
||||||
|
|
||||||
This creates a 2D cell which receives input and gives output in the first
|
This creates a 2D cell which receives input and gives output in the first
|
||||||
dimension.
|
dimension.
|
||||||
@ -408,19 +468,27 @@ class Grid2LSTMCell(GridRNNCell):
|
|||||||
forget_bias=1.0,
|
forget_bias=1.0,
|
||||||
state_is_tuple=True,
|
state_is_tuple=True,
|
||||||
output_is_tuple=True):
|
output_is_tuple=True):
|
||||||
|
|
||||||
|
def cell_fn(n):
|
||||||
|
return rnn.LSTMCell(
|
||||||
|
num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)
|
||||||
|
|
||||||
super(Grid2LSTMCell, self).__init__(
|
super(Grid2LSTMCell, self).__init__(
|
||||||
num_units=num_units, num_dims=2,
|
num_units=num_units,
|
||||||
input_dims=0, output_dims=0, priority_dims=0, tied=tied,
|
num_dims=2,
|
||||||
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
input_dims=0,
|
||||||
cell_fn=lambda n: rnn.LSTMCell(
|
output_dims=0,
|
||||||
num_units=n, forget_bias=forget_bias,
|
priority_dims=0,
|
||||||
use_peepholes=use_peepholes),
|
tied=tied,
|
||||||
non_recurrent_fn=non_recurrent_fn,
|
non_recurrent_dims=None if non_recurrent_fn is None else 0,
|
||||||
state_is_tuple=state_is_tuple, output_is_tuple=output_is_tuple)
|
cell_fn=cell_fn,
|
||||||
|
non_recurrent_fn=non_recurrent_fn,
|
||||||
|
state_is_tuple=state_is_tuple,
|
||||||
|
output_is_tuple=output_is_tuple)
|
||||||
|
|
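The grid cells above are drop-in RNNCell implementations; as a quick orientation (not part of this change), a 2D grid LSTM can be built and stepped like any other cell. A minimal sketch, assuming the contrib layout used elsewhere in this diff (tf.contrib.grid_rnn); the shapes are illustrative only:

# Minimal usage sketch (assumption: Grid2LSTMCell is exposed via
# tf.contrib.grid_rnn as defined above; shapes are illustrative).
import tensorflow as tf

batch_size, input_size, num_units = 32, 16, 64
cell = tf.contrib.grid_rnn.Grid2LSTMCell(num_units=num_units, tied=False)
inputs = tf.placeholder(tf.float32, [batch_size, input_size])
state = cell.zero_state(batch_size, tf.float32)
output, new_state = cell(inputs, state)  # one step of the 2D grid LSTM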
class Grid3LSTMCell(GridRNNCell):
  """3D BasicLSTM cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.
@@ -437,19 +505,27 @@ class Grid3LSTMCell(GridRNNCell):
               forget_bias=1.0,
               state_is_tuple=True,
               output_is_tuple=True):

    def cell_fn(n):
      return rnn.LSTMCell(
          num_units=n, forget_bias=forget_bias, use_peepholes=use_peepholes)

    super(Grid3LSTMCell, self).__init__(
        num_units=num_units,
        num_dims=3,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=cell_fn,
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


class Grid2GRUCell(GridRNNCell):
  """2D LSTM cell.

  This creates a 2D cell which receives input and gives output in the first
  dimension.
@@ -457,23 +533,31 @@ class Grid2GRUCell(GridRNNCell):
  specified.
  """

  def __init__(self,
               num_units,
               tied=False,
               non_recurrent_fn=None,
               state_is_tuple=True,
               output_is_tuple=True):
    super(Grid2GRUCell, self).__init__(
        num_units=num_units,
        num_dims=2,
        input_dims=0,
        output_dims=0,
        priority_dims=0,
        tied=tied,
        non_recurrent_dims=None if non_recurrent_fn is None else 0,
        cell_fn=lambda n: rnn.GRUCell(num_units=n),
        non_recurrent_fn=non_recurrent_fn,
        state_is_tuple=state_is_tuple,
        output_is_tuple=output_is_tuple)


# Helpers

_GridRNNDimension = namedtuple('_GridRNNDimension', [
    'idx', 'is_input', 'is_output', 'is_priority', 'non_recurrent_fn'
])

_GridRNNConfig = namedtuple('_GridRNNConfig',
                            ['num_dims', 'dims', 'inputs', 'outputs',
@@ -502,23 +586,23 @@ def _parse_rnn_config(num_dims, ls_input_dims, ls_output_dims, ls_priority_dims,
  rnn_dims = []
  for i in range(num_dims):
    rnn_dims.append(
        _GridRNNDimension(
            idx=i,
            is_input=(i in input_dims),
            is_output=(i in output_dims),
            is_priority=(i in priority_dims),
            non_recurrent_fn=non_recurrent_fn
            if i in non_recurrent_dims else None))
  return _GridRNNConfig(
      num_dims=num_dims,
      dims=rnn_dims,
      inputs=input_dims,
      outputs=output_dims,
      recurrents=[x for x in range(num_dims) if x not in non_recurrent_dims],
      priority=priority_dims,
      non_priority=[x for x in range(num_dims) if x not in priority_dims],
      tied=tied,
      num_units=num_units)


def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
@@ -544,8 +628,8 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
    cell_inputs = array_ops.zeros([m_prev[0].get_shape().as_list()[0], 0],
                                  m_prev[0].dtype)

  last_dim_output = (new_output[-1]
                     if new_output[-1] is not None else m_prev[-1])

  for i in dim_indices:
    d = conf.dims[i]
@@ -560,12 +644,12 @@ def _propagate(dim_indices, conf, cells, c_prev, m_prev, new_output, new_state,
          vs.get_variable_scope().reuse_variables()

        new_output[d.idx] = layers.fully_connected(
            linear_args,
            num_outputs=conf.num_units,
            activation_fn=d.non_recurrent_fn,
            weights_initializer=(vs.get_variable_scope().initializer or
                                 layers.initializers.xavier_initializer),
            weights_regularizer=vs.get_variable_scope().regularizer)
    else:
      if c_prev[i] is not None:
        cell_state = (c_prev[i], last_dim_output)

@@ -43,13 +43,29 @@ template class FillProjectiveTransform<CPUDevice, double>;
typedef Eigen::ThreadPoolDevice CPUDevice;

using functor::FillProjectiveTransform;
using generator::INTERPOLATION_BILINEAR;
using generator::INTERPOLATION_NEAREST;
using generator::Interpolation;
using generator::ProjectiveGenerator;

template <typename Device, typename T>
class ImageProjectiveTransform : public OpKernel {
 private:
  Interpolation interpolation_;

 public:
  explicit ImageProjectiveTransform(OpKernelConstruction* ctx) : OpKernel(ctx) {
    string interpolation_str;
    OP_REQUIRES_OK(ctx, ctx->GetAttr("interpolation", &interpolation_str));
    if (interpolation_str == "NEAREST") {
      interpolation_ = INTERPOLATION_NEAREST;
    } else if (interpolation_str == "BILINEAR") {
      interpolation_ = INTERPOLATION_BILINEAR;
    } else {
      LOG(FATAL) << "Invalid interpolation " << interpolation_str
                 << ". Supported types: NEAREST, BILINEAR";
    }
  }

  void Compute(OpKernelContext* ctx) override {
    const Tensor& images_t = ctx->input(0);
@@ -68,8 +84,8 @@ class ImageProjectiveTransform : public OpKernel {
    Tensor* output_t;
    OP_REQUIRES_OK(ctx, ctx->allocate_output(0, images_t.shape(), &output_t));
    auto output = output_t->tensor<T, 4>();
    (FillProjectiveTransform<Device, T>(interpolation_))(
        ctx->eigen_device<Device>(), &output, images, transform);
  }
};

@@ -28,6 +28,8 @@ namespace tensorflow {

namespace generator {

enum Interpolation { INTERPOLATION_NEAREST, INTERPOLATION_BILINEAR };

using Eigen::array;
using Eigen::DenseIndex;

@@ -36,20 +38,19 @@ class ProjectiveGenerator {
 private:
  typename TTypes<T, 4>::ConstTensor input_;
  typename TTypes<float>::ConstMatrix transforms_;
  const Interpolation interpolation_;

 public:
  static const int kNumParameters = 8;

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE
  ProjectiveGenerator(typename TTypes<T, 4>::ConstTensor input,
                      typename TTypes<float>::ConstMatrix transforms,
                      const Interpolation interpolation)
      : input_(input), transforms_(transforms), interpolation_(interpolation) {}

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  operator()(const array<DenseIndex, 4>& coords) const {
    const int64 output_y = coords[1];
    const int64 output_x = coords[2];
    const float* transform =
@@ -57,24 +58,73 @@ class ProjectiveGenerator {
            ? transforms_.data()
            : &transforms_.data()[transforms_.dimension(1) * coords[0]];
    float projection = transform[6] * output_x + transform[7] * output_y + 1.f;
    const float input_x =
        (transform[0] * output_x + transform[1] * output_y + transform[2]) /
        projection;
    const float input_y =
        (transform[3] * output_x + transform[4] * output_y + transform[5]) /
        projection;

    // TODO(ringwalt): Add a fill value input.
    static const T fill_value = T(0);
    switch (interpolation_) {
      case INTERPOLATION_NEAREST:
        // Switch the order of x and y again for indexing into the image.
        return nearest_interpolation(coords[0], input_y, input_x, coords[3],
                                     fill_value);
      case INTERPOLATION_BILINEAR:
        return bilinear_interpolation(coords[0], input_y, input_x, coords[3],
                                      fill_value);
    }
  }

 private:
  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  nearest_interpolation(const DenseIndex batch, const float y, const float x,
                        const DenseIndex channel, const T fill_value) const {
    return read_with_fill_value(batch, DenseIndex(std::round(y)),
                                DenseIndex(std::round(x)), channel, fill_value);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T
  bilinear_interpolation(const DenseIndex batch, const float y, const float x,
                         const DenseIndex channel, const T fill_value) const {
    const float y_floor = std::floor(y);
    const float x_floor = std::floor(x);
    const float y_ceil = y_floor + 1;
    const float x_ceil = x_floor + 1;
    // f(x, y_floor) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_floor)
    //               + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_floor)
    const float value_yfloor =
        (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_floor),
                                            DenseIndex(x_floor), channel,
                                            fill_value) +
        (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_floor),
                                             DenseIndex(x_ceil), channel,
                                             fill_value);
    // f(x, y_ceil) = (x_ceil - x) / (x_ceil - x_floor) * f(x_floor, y_ceil)
    //              + (x - x_floor) / (x_ceil - x_floor) * f(x_ceil, y_ceil)
    const float value_yceil =
        (x_ceil - x) * read_with_fill_value(batch, DenseIndex(y_ceil),
                                            DenseIndex(x_floor), channel,
                                            fill_value) +
        (x - x_floor) * read_with_fill_value(batch, DenseIndex(y_ceil),
                                             DenseIndex(x_ceil), channel,
                                             fill_value);
    // f(x, y) = (y_ceil - y) / (y_ceil - y_floor) * f(x, y_floor)
    //         + (y - y_floor) / (y_ceil - y_floor) * f(x, y_ceil)
    return T((y_ceil - y) * value_yfloor + (y - y_floor) * value_yceil);
  }

  EIGEN_DEVICE_FUNC EIGEN_ALWAYS_INLINE T read_with_fill_value(
      const DenseIndex batch, const DenseIndex y, const DenseIndex x,
      const DenseIndex channel, const T fill_value) const {
    // batch and channel must be correct, because they are passed unchanged
    // from the input.
    return (0 <= y && y < input_.dimension(1) && 0 <= x &&
            x < input_.dimension(2))
               ? input_(array<DenseIndex, 4>{batch, y, x, channel})
               : fill_value;
  }
};
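The bilinear path above is plain separable linear weighting with a constant fill value outside the image bounds. For intuition only, a small NumPy sketch of the same computation (a hand-written reference on a single 2D array, not the kernel itself):

import numpy as np

def bilinear_sample(image, y, x, fill_value=0.0):
  """Reference bilinear lookup at float coordinates (y, x) on a 2D array."""
  def read(yy, xx):
    # Constant fill outside the image, mirroring read_with_fill_value above.
    h, w = image.shape
    return image[yy, xx] if 0 <= yy < h and 0 <= xx < w else fill_value
  y0, x0 = int(np.floor(y)), int(np.floor(x))
  y1, x1 = y0 + 1, x0 + 1
  value_y0 = (x1 - x) * read(y0, x0) + (x - x0) * read(y0, x1)
  value_y1 = (x1 - x) * read(y1, x0) + (x - x0) * read(y1, x1)
  return (y1 - y) * value_y0 + (y - y0) * value_y1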
@@ -85,6 +135,7 @@ class ProjectiveGenerator {
// some Eigen device code.
namespace functor {

using generator::Interpolation;
using generator::ProjectiveGenerator;

template <typename Device, typename T>
@@ -92,15 +143,17 @@ struct FillProjectiveTransform {
  typedef typename TTypes<T, 4>::Tensor OutputType;
  typedef typename TTypes<T, 4>::ConstTensor InputType;
  typedef typename TTypes<float, 2>::ConstTensor TransformsType;
  const Interpolation interpolation_;

  FillProjectiveTransform(Interpolation interpolation)
      : interpolation_(interpolation) {}

  EIGEN_ALWAYS_INLINE
  void operator()(const Device& device, OutputType* output,
                  const InputType& images,
                  const TransformsType& transform) const {
    output->device(device) = images.generate(
        ProjectiveGenerator<Device, T>(images, transform, interpolation_));
  }
};

@@ -23,13 +23,13 @@ using shape_inference::InferenceContext;

// TODO(ringwalt): Add a "fill_mode" argument with "constant", "mirror", etc.
// TODO(ringwalt): Add a "fill_constant" argument for constant mode (default 0).
// TODO(ringwalt): Add an "output_shape" argument. This is sufficient to
// implement "same" and "valid" modes in the Python function.
REGISTER_OP("ImageProjectiveTransform")
    .Input("images: dtype")
    .Input("transforms: float32")
    .Attr("dtype: {uint8, int32, int64, float32, float64}")
    .Attr("interpolation: string")
    .Output("transformed_images: dtype")
    .SetShapeFn([](InferenceContext* c) {
      c->set_output(0, c->input(0));

@@ -25,6 +25,7 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import test_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import gradient_checker
from tensorflow.python.platform import googletest

@@ -111,6 +112,79 @@ class ImageOpsTest(test_util.TensorFlowTestCase):
                           [0, 1, 0, 1],
                           [0, 1, 1, 1]])

  def test_bilinear(self):
    with self.test_session():
      image = constant_op.constant(
          [[0, 0, 0, 0, 0],
           [0, 1, 1, 1, 0],
           [0, 1, 0, 1, 0],
           [0, 1, 1, 1, 0],
           [0, 0, 0, 0, 0]],
          dtypes.float32)
      # The following result matches:
      # >>> scipy.ndimage.rotate(image, 45, order=1, reshape=False)
      # which uses spline interpolation of order 1, equivalent to bilinear
      # interpolation.
      self.assertAllClose(
          image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
          [[0.000, 0.000, 0.343, 0.000, 0.000],
           [0.000, 0.586, 0.914, 0.586, 0.000],
           [0.343, 0.914, 0.000, 0.914, 0.343],
           [0.000, 0.586, 0.914, 0.586, 0.000],
           [0.000, 0.000, 0.343, 0.000, 0.000]],
          atol=0.001)
      self.assertAllClose(
          image_ops.rotate(image, np.pi / 4.0, interpolation="NEAREST").eval(),
          [[0, 0, 1, 0, 0],
           [0, 1, 1, 1, 0],
           [1, 1, 0, 1, 1],
           [0, 1, 1, 1, 0],
           [0, 0, 1, 0, 0]])

  def test_bilinear_uint8(self):
    with self.test_session():
      image = constant_op.constant(
          np.asarray(
              [[0.0, 0.0, 0.0, 0.0, 0.0],
               [0.0, 255, 255, 255, 0.0],
               [0.0, 255, 0.0, 255, 0.0],
               [0.0, 255, 255, 255, 0.0],
               [0.0, 0.0, 0.0, 0.0, 0.0]],
              np.uint8),
          dtypes.uint8)
      # == np.rint((expected image above) * 255)
      self.assertAllEqual(
          image_ops.rotate(image, np.pi / 4.0, interpolation="BILINEAR").eval(),
          [[0.0, 0.0, 87., 0.0, 0.0],
           [0.0, 149, 233, 149, 0.0],
           [87., 233, 0.0, 233, 87.],
           [0.0, 149, 233, 149, 0.0],
           [0.0, 0.0, 87., 0.0, 0.0]])

  def _test_grad(self, shape_to_test):
    with self.test_session():
      test_image_shape = shape_to_test
      test_image = np.random.randn(*test_image_shape)
      test_image_tensor = constant_op.constant(
          test_image, shape=test_image_shape)
      test_transform = image_ops.angles_to_projective_transforms(
          np.pi / 2, 4, 4)

      output_shape = test_image_shape
      output = image_ops.transform(test_image_tensor, test_transform)
      left_err = gradient_checker.compute_gradient_error(
          test_image_tensor,
          test_image_shape,
          output,
          output_shape,
          x_init_value=test_image)
      self.assertLess(left_err, 1e-10)

  def test_grad(self):
    self._test_grad([16, 16])
    self._test_grad([4, 12, 12])
    self._test_grad([3, 4, 12, 12])

  def _test_grad(self, shape_to_test):
    with self.test_session():

@@ -24,8 +24,8 @@ from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import resource_loader

_image_ops_so = loader.load_op_library(

@@ -37,7 +37,7 @@ _IMAGE_DTYPES = set(
ops.RegisterShape("ImageProjectiveTransform")(common_shapes.call_cpp_shape_fn)


def rotate(images, angles, interpolation="NEAREST"):
  """Rotate image(s) by the passed angle(s) in radians.

  Args:
@@ -46,6 +46,7 @@ def rotate(images, angles):
      (num_rows, num_columns) (HW).
    angles: A scalar angle to rotate all images by, or (if images has rank 4)
      a vector of length num_images, with an angle for each image in the batch.
    interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".

  Returns:
    Image(s) with the same type and shape as `images`, rotated by the given
@@ -70,7 +71,8 @@ def rotate(images, angles):
  image_width = math_ops.cast(array_ops.shape(images)[2], dtypes.float32)[None]
  output = transform(
      images,
      angles_to_projective_transforms(angles, image_height, image_width),
      interpolation=interpolation)
  if len(image_or_images.get_shape()) == 2:
    return output[0, :, :, 0]
  elif len(image_or_images.get_shape()) == 3:

@@ -120,7 +122,7 @@ def angles_to_projective_transforms(angles, image_height, image_width):
      axis=1)


def transform(images, transforms, interpolation="NEAREST"):
  """Applies the given transform(s) to the image(s).

  Args:
@@ -134,6 +136,7 @@ def transform(images, transforms):
       `(x', y') = ((a0 x + a1 y + a2) / k, (b0 x + b1 y + b2) / k)`,
       where `k = c0 x + c1 y + 1`. The transforms are *inverted* compared to
       the transform mapping input points to output points.
    interpolation: Interpolation mode. Supported values: "NEAREST", "BILINEAR".

  Returns:
    Image(s) with the same type and shape as `images`, with the given
@@ -163,8 +166,8 @@ def transform(images, transforms):
    transforms = transform_or_transforms
  else:
    raise TypeError("Transforms should have rank 1 or 2.")
  output = gen_image_ops.image_projective_transform(
      images, transforms, interpolation=interpolation.upper())
  if len(image_or_images.get_shape()) == 2:
    return output[0, :, :, 0]
  elif len(image_or_images.get_shape()) == 3:

@@ -217,8 +220,10 @@ def _transform_matrices_to_flat(transform_matrices):

@ops.RegisterGradient("ImageProjectiveTransform")
def _image_projective_transform_grad(op, grad):
  """Computes the gradient for ImageProjectiveTransform."""
  images = op.inputs[0]
  transforms = op.inputs[1]
  interpolation = op.get_attr("interpolation")

  image_or_images = ops.convert_to_tensor(images, name="images")
  transform_or_transforms = ops.convert_to_tensor(
@@ -245,7 +250,8 @@ def _image_projective_transform_grad(op, grad):
  transforms = _flat_transforms_to_matrices(transforms=transforms)
  inverse = linalg_ops.matrix_inverse(transforms)
  transforms = _transform_matrices_to_flat(inverse)
  output = gen_image_ops.image_projective_transform(
      grad, transforms, interpolation=interpolation)
  if len(image_or_images.get_shape()) == 2:
    return [output[0, :, :, 0], None]
  elif len(image_or_images.get_shape()) == 3:
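Taken together, the kernel, op registration, and Python wrapper changes above surface a single new `interpolation` argument. A minimal usage sketch, assuming the module is importable as `tf.contrib.image` (as the tests do via its internal alias); the image values are illustrative only:

import numpy as np
import tensorflow as tf

image = tf.constant(np.eye(5, dtype=np.float32))
# Rotate with the new bilinear mode; "NEAREST" remains the default.
rotated = tf.contrib.image.rotate(image, np.pi / 4.0, interpolation="BILINEAR")
# An explicit 8-parameter projective transform; this one is a pure translation
# by (dx, dy) = (1, 2). The transform maps output coordinates to input
# coordinates, hence the negated offsets.
translated = tf.contrib.image.transform(
    image, [1, 0, -1, 0, 1, -2, 0, 0], interpolation="NEAREST")
with tf.Session() as sess:
  print(sess.run([rotated, translated]))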
tensorflow/contrib/kernel_methods/README.md (new file, 55 lines)
@@ -0,0 +1,55 @@
# TensorFlow contrib kernel_methods.

This module contains operations and estimators that enable the use of primal
(explicit) kernel methods in TensorFlow. See also the
[tutorial](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/g3doc/tutorial.md)
on how to use this module to improve the quality of classification or
regression tasks.

## Kernel Mappers
Implement explicit kernel mapping Ops over tensors. Kernel mappers add
Tensor-In-Tensor-Out (TITO) Ops to the TensorFlow graph. They can be used in
conjunction with other layers or ML models.

Sample usage:

```python
kernel_mapper = tf.contrib.kernel_methods.SomeKernelMapper(...)
out_tensor = kernel_mapper.map(in_tensor)
...  # code that consumes out_tensor.
```

Currently, there is a
[RandomFourierFeatureMapper](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/mappers/random_fourier_features.py)
implemented that maps dense input to dense output.

## Kernel-based Estimators
tf.contrib.learn Estimators that use kernel mappers internally to discover
non-linearities in the data. These canned estimators map their input features
using kernel mapper Ops and then apply linear models to the mapped features.
Combining kernel mappers with linear models and different loss functions leads
to a variety of models: linear and non-linear SVMs, linear regression (with and
without kernels) and (multinomial) logistic regression (with and without
kernels).

Currently there is a
[KernelLinearClassifier](https://www.tensorflow.org/code/tensorflow/contrib/kernel_methods/python/kernel_estimators.py)
implemented but more pre-packaged estimators are on the way.

Sample usage:

```python
real_column_a = tf.contrib.layers.real_valued_column(name='real_column_a', ...)
sparse_column_b = tf.contrib.layers.sparse_column_with_hash_bucket(...)
kernel_mappers = {real_column_a: [tf.contrib.kernel_methods.SomeKernelMapper(...)]}
optimizer = ...

kernel_classifier = tf.contrib.kernel_methods.KernelLinearClassifier(
    feature_columns=[real_column_a, sparse_column_b],
    model_dir=...,
    optimizer=optimizer,
    kernel_mappers=kernel_mappers)

# Construct input_fns
kernel_classifier.fit(...)
kernel_classifier.evaluate(...)
```

tensorflow/contrib/kernel_methods/g3doc/acc-vs-trn_time.png (new binary file, 18 KiB; not shown)
tensorflow/contrib/kernel_methods/g3doc/acc_vs_outdim.png (new binary file, 19 KiB; not shown)
tensorflow/contrib/kernel_methods/g3doc/kernel_mapping.jpg (new binary file, 7.2 KiB; not shown)

tensorflow/contrib/kernel_methods/g3doc/tutorial.md (new file, 273 lines)
@@ -0,0 +1,273 @@
# Improving classification using explicit kernel methods

In this tutorial, we demonstrate how combining (explicit) kernel methods with
linear models can drastically increase the quality of a linear model's
predictions without significantly increasing training and inference times.
Currently, explicit kernel mappings are supported for dense features. Support
for sparse features is in the works.

We will use [tf.contrib.learn](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn)
(TensorFlow's high-level Machine Learning API) Estimators for our ML models.
The tf.contrib.learn API reduces the boilerplate code one needs to write for
configuring, training and evaluating models and will let us focus on the core
ideas. If you are not familiar with this API, the
[tf.contrib.learn Quickstart](https://www.tensorflow.org/get_started/tflearn)
is a good place to start. We will use MNIST, a widely-used dataset containing
images of handwritten digits (between 0 and 9). The tutorial consists of the
following steps:

* Load and prepare MNIST data for classification.
* Construct a simple linear model, train it and evaluate it on the eval data.
* Replace the linear model with a kernelized linear model, re-train and
re-evaluate.

## Load and prepare MNIST data for classification
The first step is to prepare the data to be fed to the ML models. The following
utility command from tf.contrib.learn loads the MNIST dataset:

```python
data = tf.contrib.learn.datasets.mnist.load_mnist()
```

This loads the entire MNIST dataset (containing 70K samples) and splits it into
train, validation and test data with 55K, 5K and 10K samples respectively. Each
split contains one numpy array for images (with shape [sample_size, 784]) and
one for labels (with shape [sample_size, 1]). In this tutorial, we only use the
train and validation splits (to train and evaluate our models respectively).

In order to feed data to a tf.contrib.learn Estimator, it is helpful to convert
it to Tensors. For this, we will use an `input function` which adds Ops to the
TensorFlow graph that, when executed, create mini-batches of Tensors to be used
downstream. For more background on input functions, check
[Building Input Functions with tf.contrib.learn](https://www.tensorflow.org/get_started/input_fn).
In this example, we will use the `tf.train.shuffle_batch` Op which, besides
converting numpy arrays to Tensors, allows us to specify the batch_size and
whether to randomize the input every time the input_fn Ops are executed
(randomization typically expedites convergence during training). The full code
for loading and preparing the data is shown in the snippet below. In this
example, we use mini-batches of size 256 for training and the entire sample (5K
entries) for evaluation. Feel free to experiment with different batch sizes.

```python
import numpy as np
import tensorflow as tf

def get_input_fn(dataset_split, batch_size, capacity=10000, min_after_dequeue=3000):

  def _input_fn():
    images_batch, labels_batch = tf.train.shuffle_batch(
        tensors=[dataset_split.images, dataset_split.labels.astype(np.int32)],
        batch_size=batch_size,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        enqueue_many=True,
        num_threads=4)
    features_map = {'images': images_batch}
    return features_map, labels_batch

  return _input_fn

data = tf.contrib.learn.datasets.mnist.load_mnist()

train_input_fn = get_input_fn(data.train, batch_size=256)
eval_input_fn = get_input_fn(data.validation, batch_size=5000)
```

## Training a simple linear model
We can now train a linear model over the MNIST dataset. We will use the
[tf.contrib.learn.LinearClassifier](https://www.tensorflow.org/code/tensorflow/contrib/learn/python/learn/estimators/linear.py)
estimator with 10 classes (representing the 10 digits). The input features form
a 784-dimensional (dense) vector which can be specified as follows:

```python
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
```

The full code for constructing, training and evaluating a LinearClassifier
estimator is shown below.

```python
import time

# Specify the feature(s) to be used by the estimator.
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
estimator = tf.contrib.learn.LinearClassifier(feature_columns=[image_column], n_classes=10)

# Train.
start = time.time()
estimator.fit(input_fn=train_input_fn, steps=2000)
end = time.time()
print('Elapsed time: {} seconds'.format(end - start))

# Evaluate and report metrics.
eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
print(eval_metrics)
```

On eval data, the loss (i.e., the value of the objective function being
minimized during training) lies between **0.25 and 0.30** (depending on the
parameters used) while the accuracy of the classifier is approximately **92.5%**
(training is randomized so the exact loss and accuracy will vary). Also, the
training time is around 25 seconds (this will also vary based on the machine
you run the code on).

In addition to experimenting with the (training) batch size and the number of
training steps, there are a couple other parameters that can be tuned as well.
For instance, you can change the optimization method used to minimize the loss
by explicitly selecting another optimizer from the collection of
[available optimizers](https://www.tensorflow.org/code/tensorflow/python/training).
As an example, the following code constructs a LinearClassifier estimator that
uses the Follow-The-Regularized-Leader (FTRL) optimization strategy with a
specific learning rate and l2-regularization.

```python
optimizer = tf.train.FtrlOptimizer(learning_rate=5.0, l2_regularization_strength=1.0)
estimator = tf.contrib.learn.LinearClassifier(
    feature_columns=[image_column], n_classes=10, optimizer=optimizer)
```

Regardless of the values of the parameters, the max accuracy a linear model can
achieve on this dataset caps at around **93%**.

## Using explicit kernel mappings with the linear model.
The relatively high error (~7%) of the linear model over MNIST indicates that
the input data is not linearly separable. We will use explicit kernel mappings
to reduce the classification error.

**Intuition:** The high-level idea is to use a non-linear map to transform the
input space to another feature space (of possibly higher dimension) where the
(transformed) features are (almost) linearly separable and then apply a linear
model on the mapped features. This is shown in the following figure:

<div style="text-align:center">
<img src="./kernel_mapping.png">
</div>

**Technical details overview:** In this example we will use **Random Fourier
Features** (introduced in the
["Random Features for Large-Scale Kernel Machines"](https://people.eecs.berkeley.edu/~brecht/papers/07.rah.rec.nips.pdf)
paper by Rahimi and Recht) to map the input data. Random Fourier Features map a
vector $$\mathbf{x} \in \mathbb{R}^d$$ to $$\mathbf{x'} \in \mathbb{R}^D$$ via
the following mapping:

$$
RFFM(\cdot): \mathbb{R}^d \to \mathbb{R}^D, \quad
RFFM(\mathbf{x}) = \cos(\mathbf{\Omega} \cdot \mathbf{x} + \mathbf{b})
$$

where $$\mathbf{\Omega} \in \mathbb{R}^{D \times d}$$,
$$\mathbf{x} \in \mathbb{R}^d,$$ $$\mathbf{b} \in \mathbb{R}^D$$ and cosine is
applied elementwise.

In this example, the entries of $$\mathbf{\Omega}$$ and $$\mathbf{b}$$ are
sampled from distributions such that the mapping satisfies the following
property:

$$
RFFM(\mathbf{x})^T \cdot RFFM(\mathbf{y}) \approx
e^{-\frac{\|\mathbf{x} - \mathbf{y}\|^2}{2 \sigma^2}}
$$

The right-hand-side quantity of the expression above is known as the RBF (or
Gaussian) kernel function. This function is one of the most-widely used kernel
functions in Machine Learning and measures (implicitly) similarity in a
different (much higher dimensional) space than the original one. See
[Radial basis function kernel](https://en.wikipedia.org/wiki/Radial_basis_function_kernel)
for more details.
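Before moving on to the canned estimator, it may help to see the mapping in isolation. Below is a small NumPy sketch of the approximation above; it is a hand-rolled illustration, not the `RandomFourierFeatureMapper` implementation, and it assumes the common `sqrt(2/D)` scaling (omitted from the formula above) so that the inner product approximates the RBF kernel:

```python
import numpy as np

d, D, sigma = 20, 2000, 5.0  # small d keeps the demo readable
rng = np.random.RandomState(0)
omega = rng.normal(scale=1.0 / sigma, size=(d, D))  # Omega ~ N(0, 1/sigma^2)
b = rng.uniform(0, 2 * np.pi, size=D)               # b ~ Uniform[0, 2*pi)

def rffm(x):
  """Maps x in R^d to R^D; dot products approximate the RBF kernel."""
  return np.sqrt(2.0 / D) * np.cos(x.dot(omega) + b)

x, y = rng.normal(size=d), rng.normal(size=d)
approx = rffm(x).dot(rffm(y))
exact = np.exp(-np.linalg.norm(x - y) ** 2 / (2 * sigma ** 2))
print(approx, exact)  # the two values should be close for large D
```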
**Kernel Classifier:** `tf.contrib.kernel_methods.KernelLinearClassifier` is a
pre-packaged `tf.contrib.learn` estimator that combines the power of explicit
kernel mappings with linear models. Its API is very similar to that of the
LinearClassifier with the additional ability to specify a list of explicit
kernel mappings to be applied to each feature used by the classifier. The
following code snippet demonstrates how to replace LinearClassifier with
KernelLinearClassifier.

```python
# Specify the feature(s) to be used by the estimator. This is identical to the
# code used for the LinearClassifier.
image_column = tf.contrib.layers.real_valued_column('images', dimension=784)
optimizer = tf.train.FtrlOptimizer(
    learning_rate=50.0, l2_regularization_strength=0.001)

kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
    input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
kernel_mappers = {image_column: [kernel_mapper]}
estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
    n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)

# Train.
start = time.time()
estimator.fit(input_fn=train_input_fn, steps=2000)
end = time.time()
print('Elapsed time: {} seconds'.format(end - start))

# Evaluate and report metrics.
eval_metrics = estimator.evaluate(input_fn=eval_input_fn, steps=1)
print(eval_metrics)
```

The only additional parameter passed to `KernelLinearClassifier` is a dictionary
from feature_columns to a list of kernel mappings to be applied to the
corresponding feature column. In this example, the lines

```python
kernel_mapper = tf.contrib.kernel_methods.RandomFourierFeatureMapper(
    input_dim=784, output_dim=2000, stddev=5.0, name='rffm')
kernel_mappers = {image_column: [kernel_mapper]}
estimator = tf.contrib.kernel_methods.KernelLinearClassifier(
    n_classes=10, optimizer=optimizer, kernel_mappers=kernel_mappers)
```

instruct the classifier to first map the initial 784-dimensional images to
2000-dimensional vectors using random Fourier features and then learn a linear
model on the transformed vectors. Note that, besides the output dimension,
there is one more parameter (stddev) involved. This parameter is the standard
deviation ($$\sigma$$) of the approximated RBF kernel and controls the
similarity measure used in classification. This parameter is typically
determined via hyperparameter tuning.

Running the code above yields a loss of approximately **0.10** while the
accuracy is increased to approximately **97%** on eval data (an increase of 4%
over the plain linear model). The training time hovers around 35 seconds. The
accuracy can be increased even further by increasing the output dimension of
the mapping and further tuning the standard deviation.

**On the role of stddev:** The classification quality is very sensitive to the
value of the stddev parameter used to define the similarity measure between the
pairs of input features. The following table shows the accuracy of the
classifier on the eval data for different values of stddev (for all experiments
the output dimension was fixed to 3000). The optimal value is stddev=5.0. Notice
how too small or too high stddev values can dramatically decrease the accuracy
of the classification.

stddev | eval accuracy
:----- | :------------
1.0    | 0.1362
2.0    | 0.4764
4.0    | 0.9654
5.0    | 0.9766
8.0    | 0.9714
16.0   | 0.8878

**On the role of the output dimension:** Intuitively, the larger the output
dimension of the mapping, the closer the inner product of two mapped vectors
approximates the kernel, which typically translates to better classification
accuracy. Another way to think about this is that the output dimension equals
the number of weights of the linear model (the larger this dimension, the
larger the "degrees of freedom" of the model). However, after a certain
threshold, higher output dimensions increase the accuracy by very little (while
still increasing the training time). This is shown in the following 2 figures
which depict the eval accuracy as a function of the output dimension and the
training time respectively.

![screenshot](./acc_vs_outdim.png) ![screenshot](./acc-vs-trn_time.png)
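A rough way to quantify this trade-off, from the Monte Carlo view of the mapping rather than from the experiments above: the kernel approximation error of random Fourier features decays like the inverse square root of the output dimension,

$$
\left| RFFM(\mathbf{x})^T \cdot RFFM(\mathbf{y}) -
e^{-\frac{\|\mathbf{x} - \mathbf{y}\|^2}{2 \sigma^2}} \right|
= O\!\left(\frac{1}{\sqrt{D}}\right),
$$

so doubling $$D$$ only shrinks the approximation error by roughly a factor of $$\sqrt{2}$$, which is consistent with the diminishing returns visible in the figures.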
|
|
||||||
|
## Explicit Kernel Mappings: summary and practical tips
|
||||||
|
* Explicit kernel mappings combine the predictive power of non-linear models
|
||||||
|
with the scalability of linear models.
|
||||||
|
* Random Fourier Features can be particularly effective for datasets with dense
|
||||||
|
features.
|
||||||
|
* The parameters of the kernel mapping are often data-dependent. Model quality
|
||||||
|
can be very sensitive to these parameters. Use hyperparameter tuning to find the
|
||||||
|
optimal values.
|
||||||
|
* If you have multiple numerical features, concatinate them into a single
|
||||||
|
multi-dimensional one and apply the kernel mapping to the concatenated vector.
|
||||||
|
|
@ -121,7 +121,7 @@ def embed_sequence(ids,
|
|||||||
`Tensor` of `[batch_size, doc_length, embed_dim]` with embedded sequences.
|
`Tensor` of `[batch_size, doc_length, embed_dim]` with embedded sequences.
|
||||||
|
|
||||||
Raises:
|
Raises:
|
||||||
ValueError: if `embed_dim` or `vocab_size` are not specified when
|
ValueError: if `embed_dim` or `vocab_size` are not specified when
|
||||||
`reuse` is `None` or `False`.
|
`reuse` is `None` or `False`.
|
||||||
"""
|
"""
|
||||||
if not (reuse or (vocab_size and embed_dim)):
|
if not (reuse or (vocab_size and embed_dim)):
|
||||||
|
@@ -131,21 +131,27 @@ import math
 import six

 from tensorflow.contrib import lookup
+from tensorflow.contrib.framework.python.framework import checkpoint_utils
 from tensorflow.contrib.framework.python.framework import experimental
+from tensorflow.contrib.framework.python.ops import variables as contrib_variables
+from tensorflow.contrib.layers.python.layers import embedding_ops
 from tensorflow.contrib.layers.python.layers import layers
 from tensorflow.contrib.layers.python.ops import bucketization_op
 from tensorflow.contrib.layers.python.ops import sparse_feature_cross_op
 from tensorflow.contrib.layers.python.ops import sparse_ops as contrib_sparse_ops
 from tensorflow.python.feature_column import feature_column as fc_core
 from tensorflow.python.framework import dtypes
+from tensorflow.python.framework import ops
 from tensorflow.python.framework import sparse_tensor as sparse_tensor_py
 from tensorflow.python.framework import tensor_shape
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import parsing_ops
+from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import string_ops
+from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import deprecation
@@ -291,11 +297,13 @@ class _FeatureColumn(object):


 # TODO(b/30410315): Support warm starting in all feature columns.
-class _SparseColumn(_FeatureColumn,
-                    collections.namedtuple("_SparseColumn",
-                                           ["column_name", "is_integerized",
-                                            "bucket_size", "lookup_config",
-                                            "combiner", "dtype"])):
+class _SparseColumn(
+    _FeatureColumn,
+    fc_core._CategoricalColumn,  # pylint: disable=protected-access
+    collections.namedtuple("_SparseColumn", [
+        "column_name", "is_integerized", "bucket_size", "lookup_config",
+        "combiner", "dtype"
+    ])):
   """Represents a sparse feature column also known as categorical features.

   Instances of this class are immutable. A sparse column means features are
@@ -426,9 +434,8 @@ class _SparseColumn(_FeatureColumn,
         initializer=init_ops.zeros_initializer(),
         combiner=self.combiner)

-  def _get_input_sparse_tensor(self, columns_to_tensors):
-    """Looks up the input tensor for transformation and sparsify it if dense."""
-    input_tensor = columns_to_tensors[self.name]
+  def _get_input_sparse_tensor(self, input_tensor):
+    """sparsify input_tensor if dense."""
     if not isinstance(input_tensor, sparse_tensor_py.SparseTensor):
       # To avoid making any assumptions about which values are to be ignored,
       # we set ignore_value to -1 for numeric tensors to avoid excluding valid
@@ -455,18 +462,44 @@ class _SparseColumn(_FeatureColumn,
                     format(self.name, other_column.name))
     return compatible

+  @abc.abstractmethod
+  def _do_transform(self, input_tensor):
+    pass
+
+  def insert_transformed_feature(self, columns_to_tensors):
+    """Handles sparse column to id conversion."""
+    input_tensor = self._get_input_sparse_tensor(columns_to_tensors[self.name])
+    columns_to_tensors[self] = self._do_transform(input_tensor)
+
+  def _transform_feature(self, inputs):
+    input_tensor = self._get_input_sparse_tensor(inputs.get(self.name))
+    return self._do_transform(input_tensor)
+
+  @property
+  def _parse_example_config(self):
+    return self.config
+
+  @property
+  def _num_buckets(self):
+    return self.length
+
+  def _get_sparse_tensors(self, inputs, weight_collections=None,
+                          trainable=None):
+    del weight_collections
+    del trainable
+    input_tensor = inputs.get(self)
+    return fc_core._CategoricalColumn.IdWeightPair(  # pylint: disable=protected-access
+        self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
+

 class _SparseColumnIntegerized(_SparseColumn):
   """See `sparse_column_with_integerized_feature`."""

-  def insert_transformed_feature(self, columns_to_tensors):
-    """Handles sparse column to id conversion."""
-    input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
-
+  def _do_transform(self, input_tensor):
     sparse_id_values = math_ops.mod(input_tensor.values, self.bucket_size,
                                     name="mod")
-    columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
-        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+    return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values,
+                                         input_tensor.dense_shape)


 def sparse_column_with_integerized_feature(column_name,
@@ -517,10 +550,7 @@ def sparse_column_with_integerized_feature(column_name,
 class _SparseColumnHashed(_SparseColumn):
   """See `sparse_column_with_hash_bucket`."""

-  def insert_transformed_feature(self, columns_to_tensors):
-    """Handles sparse column to id conversion."""
-    input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
-
+  def _do_transform(self, input_tensor):
     if self.dtype.is_integer:
       sparse_values = string_ops.as_string(input_tensor.values)
     else:
@@ -528,8 +558,8 @@ class _SparseColumnHashed(_SparseColumn):

     sparse_id_values = string_ops.string_to_hash_bucket_fast(
         sparse_values, self.bucket_size, name="lookup")
-    columns_to_tensors[self] = sparse_tensor_py.SparseTensor(
-        input_tensor.indices, sparse_id_values, input_tensor.dense_shape)
+    return sparse_tensor_py.SparseTensor(input_tensor.indices, sparse_id_values,
+                                         input_tensor.dense_shape)


 def sparse_column_with_hash_bucket(column_name,
@@ -572,16 +602,13 @@ def sparse_column_with_hash_bucket(column_name,
 class _SparseColumnKeys(_SparseColumn):
   """See `sparse_column_with_keys`."""

-  def insert_transformed_feature(self, columns_to_tensors):
-    """Handles sparse column to id conversion."""
-    input_tensor = self._get_input_sparse_tensor(columns_to_tensors)
-
+  def _do_transform(self, input_tensor):
     table = lookup.index_table_from_tensor(
         mapping=tuple(self.lookup_config.keys),
         default_value=self.lookup_config.default_value,
         dtype=self.dtype,
         name="lookup")
-    columns_to_tensors[self] = table.lookup(input_tensor)
+    return table.lookup(input_tensor)


 def sparse_column_with_keys(
@@ -621,9 +648,7 @@ def sparse_column_with_keys(
 class _SparseColumnVocabulary(_SparseColumn):
   """See `sparse_column_with_vocabulary_file`."""

-  def insert_transformed_feature(self, columns_to_tensors):
-    """Handles sparse column to id conversion."""
-    st = self._get_input_sparse_tensor(columns_to_tensors)
+  def _do_transform(self, st):
     if self.dtype.is_integer:
       sparse_string_values = string_ops.as_string(st.values)
       sparse_string_tensor = sparse_tensor_py.SparseTensor(st.indices,
@@ -638,7 +663,7 @@ class _SparseColumnVocabulary(_SparseColumn):
         vocab_size=self.lookup_config.vocab_size,
         default_value=self.lookup_config.default_value,
         name=self.name + "_lookup")
-    columns_to_tensors[self] = table.lookup(sparse_string_tensor)
+    return table.lookup(sparse_string_tensor)


 def sparse_column_with_vocabulary_file(column_name,
@@ -694,9 +719,12 @@ def sparse_column_with_vocabulary_file(column_name,
       dtype=dtype)


-class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
-    "_WeightedSparseColumn",
-    ["sparse_id_column", "weight_column_name", "dtype"])):
+class _WeightedSparseColumn(
+    _FeatureColumn,
+    fc_core._CategoricalColumn,  # pylint: disable=protected-access
+    collections.namedtuple("_WeightedSparseColumn",
+                           ["sparse_id_column", "weight_column_name",
+                            "dtype"])):
   """See `weighted_sparse_column`."""

   def __new__(cls, sparse_id_column, weight_column_name, dtype):
@@ -725,22 +753,6 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
     """Returns a string which will be used as a key when we do sorting."""
     return "{}".format(self)

-  def insert_transformed_feature(self, columns_to_tensors):
-    """Inserts a tuple with the id and weight tensors."""
-    if self.sparse_id_column not in columns_to_tensors:
-      self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
-
-    weight_tensor = columns_to_tensors[self.weight_column_name]
-    if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
-      # The weight tensor can be a regular Tensor. In such case, sparsify it.
-      weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
-    if not self.dtype.is_floating:
-      weight_tensor = math_ops.to_float(weight_tensor)
-    columns_to_tensors[self] = tuple([
-        columns_to_tensors[self.sparse_id_column],
-        weight_tensor
-    ])
-
   def id_tensor(self, input_tensor):
     """Returns the id tensor from the given transformed input_tensor."""
     return input_tensor[0]
@@ -768,6 +780,43 @@ class _WeightedSparseColumn(_FeatureColumn, collections.namedtuple(
         initializer=init_ops.zeros_initializer(),
         combiner=self.sparse_id_column.combiner)

+  def _do_transform(self, id_tensor, weight_tensor):
+    if not isinstance(weight_tensor, sparse_tensor_py.SparseTensor):
+      # The weight tensor can be a regular Tensor. In such case, sparsify it.
+      weight_tensor = contrib_sparse_ops.dense_to_sparse_tensor(weight_tensor)
+    if not self.dtype.is_floating:
+      weight_tensor = math_ops.to_float(weight_tensor)
+    return tuple([id_tensor, weight_tensor])
+
+  def insert_transformed_feature(self, columns_to_tensors):
+    """Inserts a tuple with the id and weight tensors."""
+    if self.sparse_id_column not in columns_to_tensors:
+      self.sparse_id_column.insert_transformed_feature(columns_to_tensors)
+
+    weight_tensor = columns_to_tensors[self.weight_column_name]
+    columns_to_tensors[self] = self._do_transform(
+        columns_to_tensors[self.sparse_id_column], weight_tensor)
+
+  def _transform_feature(self, inputs):
+    return self._do_transform(
+        inputs.get(self.sparse_id_column), inputs.get(self.weight_column_name))
+
+  @property
+  def _parse_example_config(self):
+    return self.config
+
+  @property
+  def _num_buckets(self):
+    return self.length
+
+  def _get_sparse_tensors(self, inputs, weight_collections=None,
+                          trainable=None):
+    del weight_collections
+    del trainable
+    input_tensor = inputs.get(self)
+    return fc_core._CategoricalColumn.IdWeightPair(  # pylint: disable=protected-access
+        self.id_tensor(input_tensor), self.weight_tensor(input_tensor))
+

 def weighted_sparse_column(sparse_id_column,
                            weight_column_name,
@@ -815,9 +864,10 @@ def weighted_sparse_column(sparse_id_column,
   return _WeightedSparseColumn(sparse_id_column, weight_column_name, dtype)


-class _OneHotColumn(_FeatureColumn,
-                    collections.namedtuple("_OneHotColumn",
-                                           ["sparse_id_column"])):
+class _OneHotColumn(
+    _FeatureColumn,
+    fc_core._DenseColumn,  # pylint: disable=protected-access
+    collections.namedtuple("_OneHotColumn", ["sparse_id_column"])):
   """Represents a one-hot column for use in deep networks.

   Args:
@@ -897,12 +947,31 @@ class _OneHotColumn(_FeatureColumn,
     return math_ops.reduce_sum(
         one_hot_id_tensor, reduction_indices=[output_rank - 1])

+  @property
+  def _variable_shape(self):
+    return tensor_shape.TensorShape((self.length))
+
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+    del weight_collections
+    del trainable
+    return inputs.get(self)
+
+  def _transform_feature(self, inputs):
+    return self._to_dnn_input_layer(inputs.get(self.sparse_id_column))
+
+  @property
+  def _parse_example_config(self):
+    return self.config
+

-class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
-    "_EmbeddingColumn",
-    ["sparse_id_column", "dimension", "combiner", "initializer",
-     "ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name",
-     "shared_vocab_size", "max_norm", "trainable"])):
+class _EmbeddingColumn(
+    _FeatureColumn,
+    fc_core._DenseColumn,  # pylint: disable=protected-access
+    collections.namedtuple("_EmbeddingColumn", [
+        "sparse_id_column", "dimension", "combiner", "initializer",
+        "ckpt_to_load_from", "tensor_name_in_ckpt", "shared_embedding_name",
+        "shared_vocab_size", "max_norm", "trainable"
+    ])):
   """Represents an embedding column.

   Args:
@@ -1027,6 +1096,139 @@ class _EmbeddingColumn(_FeatureColumn, collections.namedtuple(
     raise ValueError("Column {} is not supported in linear models. "
                      "Please use sparse_column.".format(self))

+  @property
+  def _variable_shape(self):
+    return tensor_shape.TensorShape((self.dimension))
+
+  def _get_dense_tensor(self, inputs, weight_collections=None, trainable=None):
+    return _embeddings_from_arguments(
+        self,
+        self._deep_embedding_lookup_arguments(inputs.get(self)),
+        weight_collections, trainable)
+
+  def _transform_feature(self, inputs):
+    return inputs.get(self.sparse_id_column)
+
+  @property
+  def _parse_example_config(self):
+    return self.config
+
+
+def _is_variable(v):
+  """Returns true if `v` is a variable."""
+  return isinstance(v, (variables.Variable,
+                        resource_variable_ops.ResourceVariable))
+
+
+def _embeddings_from_arguments(column,
+                               args,
+                               weight_collections,
+                               trainable,
+                               output_rank=2):
+  """Returns embeddings for a column based on the computed arguments.
+
+  Args:
+    column: the column name.
+    args: the _DeepEmbeddingLookupArguments for this column.
+    weight_collections: collections to store weights in.
+    trainable: whether these embeddings should be trainable.
+    output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
+      be combined to produce the desired rank.
+
+  Returns:
+    the embeddings.
+
+  Raises:
+    ValueError: if not possible to create.
+  """
+  # pylint: disable=protected-access
+  input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
+  weight_tensor = None
+  if args.weight_tensor is not None:
+    weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
+  # pylint: enable=protected-access
+
+  # This option is only enabled for scattered_embedding_column.
+  if args.hash_key:
+    embeddings = contrib_variables.model_variable(
+        name="weights",
+        shape=[args.vocab_size],
+        dtype=dtypes.float32,
+        initializer=args.initializer,
+        trainable=(trainable and args.trainable),
+        collections=weight_collections)
+
+    return embedding_ops.scattered_embedding_lookup_sparse(
+        embeddings,
+        input_tensor,
+        args.dimension,
+        hash_key=args.hash_key,
+        combiner=args.combiner,
+        name="lookup")
+
+  if args.shared_embedding_name is not None:
+    shared_embedding_collection_name = (
+        "SHARED_EMBEDDING_COLLECTION_" + args.shared_embedding_name.upper())
+    graph = ops.get_default_graph()
+    shared_embedding_collection = (
+        graph.get_collection_ref(shared_embedding_collection_name))
+    shape = [args.vocab_size, args.dimension]
+    if shared_embedding_collection:
+      if len(shared_embedding_collection) > 1:
+        raise ValueError(
+            "Collection %s can only contain one "
+            "(partitioned) variable." % shared_embedding_collection_name)
+      else:
+        embeddings = shared_embedding_collection[0]
+        if embeddings.get_shape() != shape:
+          raise ValueError(
+              "The embedding variable with name {} already "
+              "exists, but its shape does not match required "
+              "embedding shape here. Please make sure to use "
+              "different shared_embedding_name for different "
+              "shared embeddings.".format(args.shared_embedding_name))
+    else:
+      embeddings = contrib_variables.model_variable(
+          name=args.shared_embedding_name,
+          shape=shape,
+          dtype=dtypes.float32,
+          initializer=args.initializer,
+          trainable=(trainable and args.trainable),
+          collections=weight_collections)
+      graph.add_to_collection(shared_embedding_collection_name, embeddings)
+  else:
+    embeddings = contrib_variables.model_variable(
+        name="weights",
+        shape=[args.vocab_size, args.dimension],
+        dtype=dtypes.float32,
+        initializer=args.initializer,
+        trainable=(trainable and args.trainable),
+        collections=weight_collections)
+
+  if _is_variable(embeddings):
+    embeddings = [embeddings]
+  else:
+    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
+  # pylint: disable=protected-access
+  _maybe_restore_from_checkpoint(column._checkpoint_path(), embeddings)
+  return embedding_ops.safe_embedding_lookup_sparse(
+      embeddings,
+      input_tensor,
+      sparse_weights=weight_tensor,
+      combiner=args.combiner,
+      name=column.name + "weights",
+      max_norm=args.max_norm)
+
+
+def _maybe_restore_from_checkpoint(checkpoint_path, variable):
+  if checkpoint_path is not None:
+    path, tensor_name = checkpoint_path
+    weights_to_restore = variable
+    if len(variable) == 1:
+      weights_to_restore = variable[0]
+    checkpoint_utils.init_from_checkpoint(path,
+                                          {tensor_name: weights_to_restore})
+
+
 def one_hot_column(sparse_id_column):
   """Creates an `_OneHotColumn` for a one-hot or multi-hot repr in a DNN.
@@ -20,7 +20,6 @@ from __future__ import print_function

 import functools

-from tensorflow.contrib.framework.python.framework import checkpoint_utils
 from tensorflow.contrib.framework.python.framework import experimental
 from tensorflow.contrib.framework.python.ops import variables as contrib_variables
 from tensorflow.contrib.layers.python.layers import embedding_ops
@@ -34,118 +33,12 @@ from tensorflow.python.ops import init_ops
 from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import nn_ops
 from tensorflow.python.ops import parsing_ops
-from tensorflow.python.ops import resource_variable_ops
 from tensorflow.python.ops import sparse_ops
 from tensorflow.python.ops import variable_scope
-from tensorflow.python.ops import variables
 from tensorflow.python.platform import tf_logging as logging
 from tensorflow.python.util import nest


-def _is_variable(v):
-  """Returns true if `v` is a variable."""
-  return isinstance(v, (variables.Variable,
-                        resource_variable_ops.ResourceVariable))
-
-
-def _embeddings_from_arguments(column,
-                               args,
-                               weight_collections,
-                               trainable,
-                               output_rank=2):
-  """Returns embeddings for a column based on the computed arguments.
-
-  Args:
-    column: the column name.
-    args: the _DeepEmbeddingLookupArguments for this column.
-    weight_collections: collections to store weights in.
-    trainable: whether these embeddings should be trainable.
-    output_rank: the desired rank of the returned `Tensor`. Inner dimensions will
-      be combined to produce the desired rank.
-
-  Returns:
-    the embeddings.
-
-  Raises:
-    ValueError: if not possible to create.
-  """
-  # pylint: disable=protected-access
-  input_tensor = layers._inner_flatten(args.input_tensor, output_rank)
-  weight_tensor = None
-  if args.weight_tensor is not None:
-    weight_tensor = layers._inner_flatten(args.weight_tensor, output_rank)
-  # pylint: enable=protected-access
-
-  # This option is only enabled for scattered_embedding_column.
-  if args.hash_key:
-    embeddings = contrib_variables.model_variable(
-        name='weights',
-        shape=[args.vocab_size],
-        dtype=dtypes.float32,
-        initializer=args.initializer,
-        trainable=(trainable and args.trainable),
-        collections=weight_collections)
-
-    return embedding_ops.scattered_embedding_lookup_sparse(
-        embeddings, input_tensor, args.dimension,
-        hash_key=args.hash_key,
-        combiner=args.combiner, name='lookup')
-
-  if args.shared_embedding_name is not None:
-    shared_embedding_collection_name = (
-        'SHARED_EMBEDDING_COLLECTION_' + args.shared_embedding_name.upper())
-    graph = ops.get_default_graph()
-    shared_embedding_collection = (
-        graph.get_collection_ref(shared_embedding_collection_name))
-    shape = [args.vocab_size, args.dimension]
-    if shared_embedding_collection:
-      if len(shared_embedding_collection) > 1:
-        raise ValueError('Collection %s can only contain one '
-                         '(partitioned) variable.'
-                         % shared_embedding_collection_name)
-      else:
-        embeddings = shared_embedding_collection[0]
-        if embeddings.get_shape() != shape:
-          raise ValueError('The embedding variable with name {} already '
-                           'exists, but its shape does not match required '
-                           'embedding shape here. Please make sure to use '
-                           'different shared_embedding_name for different '
-                           'shared embeddings.'.format(
-                               args.shared_embedding_name))
-    else:
-      embeddings = contrib_variables.model_variable(
-          name=args.shared_embedding_name,
-          shape=shape,
-          dtype=dtypes.float32,
-          initializer=args.initializer,
-          trainable=(trainable and args.trainable),
-          collections=weight_collections)
-      graph.add_to_collection(shared_embedding_collection_name, embeddings)
-  else:
-    embeddings = contrib_variables.model_variable(
-        name='weights',
-        shape=[args.vocab_size, args.dimension],
-        dtype=dtypes.float32,
-        initializer=args.initializer,
-        trainable=(trainable and args.trainable),
-        collections=weight_collections)
-
-  if _is_variable(embeddings):
-    embeddings = [embeddings]
-  else:
-    embeddings = embeddings._get_variable_list()  # pylint: disable=protected-access
-  # pylint: disable=protected-access
-  _maybe_restore_from_checkpoint(
-      column._checkpoint_path(), embeddings)
-  return embedding_ops.safe_embedding_lookup_sparse(
-      embeddings,
-      input_tensor,
-      sparse_weights=weight_tensor,
-      combiner=args.combiner,
-      name=column.name + 'weights',
-      max_norm=args.max_norm)
-
-
 def _maybe_reshape_input_tensor(tensor, column_name, output_rank):
   """Reshape the input tensor by the following rule.

@@ -232,12 +125,13 @@ def _input_from_feature_columns(columns_to_tensors,
         # pylint: disable=protected-access
         arguments = column._deep_embedding_lookup_arguments(
             transformed_tensor)
-        output_tensors.append(_embeddings_from_arguments(
-            column,
-            arguments,
-            weight_collections,
-            trainable,
-            output_rank=output_rank))
+        output_tensors.append(
+            fc._embeddings_from_arguments(  # pylint: disable=protected-access
+                column,
+                arguments,
+                weight_collections,
+                trainable,
+                output_rank=output_rank))

       except NotImplementedError as ee:
         try:
@@ -393,7 +287,7 @@ def _create_embedding_lookup(column,
         initializer=embedding_lookup_arguments.initializer,
         trainable=trainable,
         collections=weight_collections)
-    if _is_variable(variable):
+    if fc._is_variable(variable):  # pylint: disable=protected-access
       variable = [variable]
     else:
       variable = variable._get_variable_list()  # pylint: disable=protected-access
@ -406,16 +300,6 @@ def _create_embedding_lookup(column,
|
|||||||
return variable, predictions
|
return variable, predictions
|
||||||
|
|
||||||
|
|
||||||
def _maybe_restore_from_checkpoint(checkpoint_path, variable):
|
|
||||||
if checkpoint_path is not None:
|
|
||||||
path, tensor_name = checkpoint_path
|
|
||||||
weights_to_restore = variable
|
|
||||||
if len(variable) == 1:
|
|
||||||
weights_to_restore = variable[0]
|
|
||||||
checkpoint_utils.init_from_checkpoint(path,
|
|
||||||
{tensor_name: weights_to_restore})
|
|
||||||
|
|
||||||
|
|
||||||
def _create_joint_embedding_lookup(columns_to_tensors,
|
def _create_joint_embedding_lookup(columns_to_tensors,
|
||||||
embedding_lookup_arguments,
|
embedding_lookup_arguments,
|
||||||
num_outputs,
|
num_outputs,
|
||||||
@ -451,7 +335,7 @@ def _create_joint_embedding_lookup(columns_to_tensors,
|
|||||||
initializer=init_ops.zeros_initializer(),
|
initializer=init_ops.zeros_initializer(),
|
||||||
trainable=trainable,
|
trainable=trainable,
|
||||||
collections=weight_collections)
|
collections=weight_collections)
|
||||||
if _is_variable(variable):
|
if fc._is_variable(variable): # pylint: disable=protected-access
|
||||||
variable = [variable]
|
variable = [variable]
|
||||||
else:
|
else:
|
||||||
variable = variable._get_variable_list() # pylint: disable=protected-access
|
variable = variable._get_variable_list() # pylint: disable=protected-access
|
||||||
@ -634,7 +518,7 @@ def weighted_sum_from_feature_columns(columns_to_tensors,
|
|||||||
predictions, shape=(-1, num_outputs)))
|
predictions, shape=(-1, num_outputs)))
|
||||||
column_to_variable[column] = variable
|
column_to_variable[column] = variable
|
||||||
_log_variable(variable)
|
_log_variable(variable)
|
||||||
_maybe_restore_from_checkpoint(column._checkpoint_path(), variable)
|
fc._maybe_restore_from_checkpoint(column._checkpoint_path(), variable) # pylint: disable=protected-access
|
||||||
# pylint: enable=protected-access
|
# pylint: enable=protected-access
|
||||||
predictions_no_bias = math_ops.add_n(output_tensors)
|
predictions_no_bias = math_ops.add_n(output_tensors)
|
||||||
bias = contrib_variables.model_variable(
|
bias = contrib_variables.model_variable(
|
||||||
@ -827,10 +711,10 @@ def parse_feature_columns_from_sequence_examples(
|
|||||||
def _log_variable(variable):
|
def _log_variable(variable):
|
||||||
if isinstance(variable, list):
|
if isinstance(variable, list):
|
||||||
for var in variable:
|
for var in variable:
|
||||||
if _is_variable(variable):
|
if fc._is_variable(variable): # pylint: disable=protected-access
|
||||||
logging.info('Created variable %s, with device=%s', var.name,
|
logging.info('Created variable %s, with device=%s', var.name,
|
||||||
var.device)
|
var.device)
|
||||||
elif _is_variable(variable):
|
elif fc._is_variable(variable): # pylint: disable=protected-access
|
||||||
logging.info('Created variable %s, with device=%s', variable.name,
|
logging.info('Created variable %s, with device=%s', variable.name,
|
||||||
variable.device)
|
variable.device)
|
||||||
|
|
||||||
|
@ -597,12 +597,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
"income":
|
"income":
|
||||||
constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]),
|
constant_op.constant([[20.3, 10], [110.3, 0.4], [-3.0, 30.4]]),
|
||||||
}
|
}
|
||||||
output = feature_column_ops.input_from_feature_columns(features, [
|
columns = [one_hot_column, embedding_column, real_valued_column]
|
||||||
one_hot_column, embedding_column, real_valued_column])
|
output = feature_column_ops.input_from_feature_columns(features, columns)
|
||||||
|
output_core = fc_core.make_input_layer(features, columns)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
|
self.assertAllEqual(output.eval().shape, [3, 2 + 4 + 10])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
|
||||||
|
|
||||||
def testRealValuedColumn(self):
|
def testRealValuedColumn(self):
|
||||||
real_valued = feature_column.real_valued_column("price")
|
real_valued = feature_column.real_valued_column("price")
|
||||||
@ -712,11 +715,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
one_hot_column = feature_column.one_hot_column(weighted_ids_column)
|
one_hot_column = feature_column.one_hot_column(weighted_ids_column)
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[one_hot_column])
|
[one_hot_column])
|
||||||
|
output_core = fc_core.make_input_layer(features, [one_hot_column])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
|
self.assertAllEqual([[0, 0, 10., 0], [0, 20., 0, 0], [30., 0, 40., 0]],
|
||||||
output.eval())
|
output.eval())
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval(), output_core.eval())
|
||||||
|
|
||||||
def testOneHotColumnFromSparseColumnWithKeysSucceedsForDNN(self):
|
def testOneHotColumnFromSparseColumnWithKeysSucceedsForDNN(self):
|
||||||
ids_column = feature_column.sparse_column_with_keys(
|
ids_column = feature_column.sparse_column_with_keys(
|
||||||
@ -729,12 +735,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
features = {"ids": ids_tensor}
|
features = {"ids": ids_tensor}
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[one_hot_sparse])
|
[one_hot_sparse])
|
||||||
|
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
|
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 0, 0]],
|
||||||
output.eval())
|
output.eval())
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval(), output_core.eval())
|
||||||
|
|
||||||
def testOneHotColumnFromMultivalentSparseColumnWithKeysSucceedsForDNN(self):
|
def testOneHotColumnFromMultivalentSparseColumnWithKeysSucceedsForDNN(self):
|
||||||
ids_column = feature_column.sparse_column_with_keys(
|
ids_column = feature_column.sparse_column_with_keys(
|
||||||
@ -747,12 +756,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
features = {"ids": ids_tensor}
|
features = {"ids": ids_tensor}
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[one_hot_sparse])
|
[one_hot_sparse])
|
||||||
|
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
||||||
output.eval())
|
output.eval())
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval(), output_core.eval())
|
||||||
|
|
||||||
def testOneHotColumnFromSparseColumnWithIntegerizedFeaturePassesForDNN(self):
|
def testOneHotColumnFromSparseColumnWithIntegerizedFeaturePassesForDNN(self):
|
||||||
ids_column = feature_column.sparse_column_with_integerized_feature(
|
ids_column = feature_column.sparse_column_with_integerized_feature(
|
||||||
@ -767,10 +779,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
}
|
}
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[one_hot_sparse])
|
[one_hot_sparse])
|
||||||
|
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
self.assertAllEqual([[0, 0, 1, 0], [0, 1, 0, 0], [1, 0, 1, 0]],
|
||||||
output.eval())
|
output.eval())
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval(), output_core.eval())
|
||||||
|
|
||||||
def testOneHotColumnFromSparseColumnWithHashBucketSucceedsForDNN(self):
|
def testOneHotColumnFromSparseColumnWithHashBucketSucceedsForDNN(self):
|
||||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("feat", 10)
|
hashed_sparse = feature_column.sparse_column_with_hash_bucket("feat", 10)
|
||||||
@ -782,10 +797,13 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
one_hot_sparse = feature_column.one_hot_column(hashed_sparse)
|
one_hot_sparse = feature_column.one_hot_column(hashed_sparse)
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[one_hot_sparse])
|
[one_hot_sparse])
|
||||||
|
output_core = fc_core.make_input_layer(features, [one_hot_sparse])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual([3, 10], output.eval().shape)
|
self.assertAllEqual([3, 10], output.eval().shape)
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval(), output_core.eval())
|
||||||
|
|
||||||
def testEmbeddingColumnSucceedsForDNN(self):
|
def testEmbeddingColumnSucceedsForDNN(self):
|
||||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||||
@ -797,9 +815,12 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
|
embeded_sparse = feature_column.embedding_column(hashed_sparse, 10)
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[embeded_sparse])
|
[embeded_sparse])
|
||||||
|
output_core = fc_core.make_input_layer(features, [embeded_sparse])
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
self.assertAllEqual(output.eval().shape, [4, 10])
|
self.assertAllEqual(output.eval().shape, [4, 10])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
|
||||||
|
|
||||||
def testScatteredEmbeddingColumnSucceedsForDNN(self):
|
def testScatteredEmbeddingColumnSucceedsForDNN(self):
|
||||||
wire_tensor = sparse_tensor.SparseTensor(
|
wire_tensor = sparse_tensor.SparseTensor(
|
||||||
@ -838,12 +859,15 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
initializer=init_ops.constant_initializer(init_value))
|
initializer=init_ops.constant_initializer(init_value))
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[embeded_sparse])
|
[embeded_sparse])
|
||||||
|
output_core = fc_core.make_input_layer(features, [embeded_sparse])
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
output_eval = output.eval()
|
output_eval = output.eval()
|
||||||
self.assertAllEqual(output_eval.shape, [2, 10])
|
self.assertAllEqual(output_eval.shape, [2, 10])
|
||||||
self.assertAllClose(output_eval, np.tile(init_value, [2, 10]))
|
self.assertAllClose(output_eval, np.tile(init_value, [2, 10]))
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval(), output_core.eval())
|
||||||
|
|
||||||
def testEmbeddingColumnWithMultipleInitializersFails(self):
|
def testEmbeddingColumnWithMultipleInitializersFails(self):
|
||||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||||
@ -889,10 +913,14 @@ class CreateInputLayersForDNNsTest(test.TestCase):
|
|||||||
embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
|
embeded_sparse = feature_column.embedding_column(weighted_ids, 10)
|
||||||
output = feature_column_ops.input_from_feature_columns(features,
|
output = feature_column_ops.input_from_feature_columns(features,
|
||||||
[embeded_sparse])
|
[embeded_sparse])
|
||||||
|
output_core = fc_core.make_input_layer(features, [embeded_sparse])
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(output.eval().shape, [2, 10])
|
self.assertAllEqual(output.eval().shape, [2, 10])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval().shape, output_core.eval().shape)
|
||||||
|
|
||||||
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
|
def testEmbeddingColumnWithIntegerWeightedSparseColumnSucceedsForDNN(self):
|
||||||
"""Same as the previous test, but with integer weights."""
|
"""Same as the previous test, but with integer weights."""
|
||||||
@ -1534,9 +1562,12 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features = {"wire": wire_tensor}
|
features = {"wire": wire_tensor}
|
||||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||||
features, [hashed_sparse], num_outputs=5)
|
features, [hashed_sparse], num_outputs=5)
|
||||||
|
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||||
|
|
||||||
def testSparseIntColumn(self):
|
def testSparseIntColumn(self):
|
||||||
"""Tests a sparse column with int values."""
|
"""Tests a sparse column with int values."""
|
||||||
@ -1549,9 +1580,12 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features = {"wire": wire_tensor}
|
features = {"wire": wire_tensor}
|
||||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||||
features, [hashed_sparse], num_outputs=5)
|
features, [hashed_sparse], num_outputs=5)
|
||||||
|
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||||
|
|
||||||
def testSparseColumnWithDenseInputTensor(self):
|
def testSparseColumnWithDenseInputTensor(self):
|
||||||
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
hashed_sparse = feature_column.sparse_column_with_hash_bucket("wire", 10)
|
||||||
@ -1560,9 +1594,12 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features = {"wire": wire_tensor}
|
features = {"wire": wire_tensor}
|
||||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||||
features, [hashed_sparse], num_outputs=5)
|
features, [hashed_sparse], num_outputs=5)
|
||||||
|
logits_core = fc_core.make_linear_model(features, [hashed_sparse], units=5)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||||
|
|
||||||
def testWeightedSparseColumn(self):
|
def testWeightedSparseColumn(self):
|
||||||
ids = feature_column.sparse_column_with_keys("ids",
|
ids = feature_column.sparse_column_with_keys("ids",
|
||||||
@ -1579,10 +1616,13 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features = {"ids": ids_tensor, "weights": weights_tensor}
|
features = {"ids": ids_tensor, "weights": weights_tensor}
|
||||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||||
features, [weighted_ids], num_outputs=5)
|
features, [weighted_ids], num_outputs=5)
|
||||||
|
logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||||
|
|
||||||
def testWeightedSparseColumnWithDenseInputTensor(self):
|
def testWeightedSparseColumnWithDenseInputTensor(self):
|
||||||
ids = feature_column.sparse_column_with_keys(
|
ids = feature_column.sparse_column_with_keys(
|
||||||
@ -1594,11 +1634,14 @@ class WeightedSumTest(test.TestCase):
|
|||||||
features = {"ids": ids_tensor, "weights": weights_tensor}
|
features = {"ids": ids_tensor, "weights": weights_tensor}
|
||||||
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
logits, _, _ = feature_column_ops.weighted_sum_from_feature_columns(
|
||||||
features, [weighted_ids], num_outputs=5)
|
features, [weighted_ids], num_outputs=5)
|
||||||
|
logits_core = fc_core.make_linear_model(features, [weighted_ids], units=5)
|
||||||
|
|
||||||
with self.test_session():
|
with self.test_session():
|
||||||
variables_lib.global_variables_initializer().run()
|
variables_lib.global_variables_initializer().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
self.assertAllEqual(logits.eval().shape, [2, 5])
|
self.assertAllEqual(logits.eval().shape, [2, 5])
|
||||||
|
# Verify cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(logits.eval(), logits_core.eval())
|
||||||
|
|
||||||
def testCrossedColumn(self):
|
def testCrossedColumn(self):
|
||||||
a = feature_column.sparse_column_with_hash_bucket(
|
a = feature_column.sparse_column_with_hash_bucket(
|
||||||
@ -1649,6 +1692,8 @@ class WeightedSumTest(test.TestCase):
|
|||||||
output, column_to_variable, _ = (
|
output, column_to_variable, _ = (
|
||||||
feature_column_ops.weighted_sum_from_feature_columns(
|
feature_column_ops.weighted_sum_from_feature_columns(
|
||||||
features, [movies], num_outputs=1))
|
features, [movies], num_outputs=1))
|
||||||
|
logits_core = fc_core.make_linear_model(features, [movies])
|
||||||
|
|
||||||
with self.test_session() as sess:
|
with self.test_session() as sess:
|
||||||
variables_lib.initialize_all_variables().run()
|
variables_lib.initialize_all_variables().run()
|
||||||
lookup_ops.tables_initializer().run()
|
lookup_ops.tables_initializer().run()
|
||||||
@ -1659,6 +1704,8 @@ class WeightedSumTest(test.TestCase):
|
|||||||
# score for first example = 0.3 (matrix) + 0.1 (head-on) = 0.4
|
# score for first example = 0.3 (matrix) + 0.1 (head-on) = 0.4
|
||||||
# score for second example = 0.5 (winter sleep)
|
# score for second example = 0.5 (winter sleep)
|
||||||
self.assertAllClose(output.eval(), [[0.4], [0.5]])
|
self.assertAllClose(output.eval(), [[0.4], [0.5]])
|
||||||
|
# Cross compatibility: Core builder output should equal to contrib.
|
||||||
|
self.assertAllEqual(output.eval().shape, logits_core.eval().shape)
|
||||||
|
|
||||||
def testRealValuedColumnWithMultiDimensions(self):
|
def testRealValuedColumnWithMultiDimensions(self):
|
||||||
real_valued = feature_column.real_valued_column("price", 2)
|
real_valued = feature_column.real_valued_column("price", 2)
|
||||||
|
@ -36,7 +36,8 @@ def xavier_initializer(uniform=True, seed=None, dtype=dtypes.float32):
|
|||||||
Xavier Glorot and Yoshua Bengio (2010):
|
Xavier Glorot and Yoshua Bengio (2010):
|
||||||
[Understanding the difficulty of training deep feedforward neural
|
[Understanding the difficulty of training deep feedforward neural
|
||||||
networks. International conference on artificial intelligence and
|
networks. International conference on artificial intelligence and
|
||||||
statistics.](http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.207.2059&rep=rep1&type=pdf)
|
statistics.](
|
||||||
|
http://www.jmlr.org/proceedings/papers/v9/glorot10a/glorot10a.pdf)
|
||||||
|
|
||||||
This initializer is designed to keep the scale of the gradients roughly the
|
This initializer is designed to keep the scale of the gradients roughly the
|
||||||
same in all layers. In uniform distribution this ends up being the range:
|
same in all layers. In uniform distribution this ends up being the range:
|
||||||
|
@ -102,9 +102,10 @@ def _linear_learning_rate(num_linear_feature_columns):
|
|||||||
def _add_hidden_layer_summary(value, tag):
|
def _add_hidden_layer_summary(value, tag):
|
||||||
summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
|
summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
|
||||||
summary.histogram("%s/activation" % tag, value)
|
summary.histogram("%s/activation" % tag, value)
|
||||||
|
|
||||||
|
|
||||||
def _add_layer_summary(value, tag):
|
def _add_layer_summary(value, tag):
|
||||||
summary.scalar("%s/fraction_of_zero_values" % tag,
|
summary.scalar("%s/fraction_of_zero_values" % tag, nn.zero_fraction(value))
|
||||||
nn.zero_fraction(value))
|
|
||||||
summary.histogram("%s/activation" % tag, value)
|
summary.histogram("%s/activation" % tag, value)
|
||||||
|
|
||||||
|
|
||||||
|
@ -19,7 +19,6 @@ from __future__ import division
|
|||||||
from __future__ import print_function
|
from __future__ import print_function
|
||||||
|
|
||||||
from tensorflow.contrib import layers
|
from tensorflow.contrib import layers
|
||||||
from tensorflow.contrib.framework.python.framework import deprecated
|
|
||||||
from tensorflow.contrib.layers.python.layers import optimizers
|
from tensorflow.contrib.layers.python.layers import optimizers
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import constants
|
from tensorflow.contrib.learn.python.learn.estimators import constants
|
||||||
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
from tensorflow.contrib.learn.python.learn.estimators import estimator
|
||||||
|
@ -32,6 +32,7 @@ from tensorflow.python.ops import state_ops
|
|||||||
from tensorflow.python.summary import summary
|
from tensorflow.python.summary import summary
|
||||||
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
from tensorflow.python.ops.control_flow_ops import with_dependencies
|
||||||
from tensorflow.python.platform import tf_logging as logging
|
from tensorflow.python.platform import tf_logging as logging
|
||||||
|
from tensorflow.python.summary import summary
|
||||||
from tensorflow.python.training import session_run_hook
|
from tensorflow.python.training import session_run_hook
|
||||||
from tensorflow.python.training.session_run_hook import SessionRunArgs
|
from tensorflow.python.training.session_run_hook import SessionRunArgs
|
||||||
|
|
||||||
|
@@ -20,7 +20,6 @@ from __future__ import print_function

from tensorflow.contrib import layers
from tensorflow.contrib import rnn as rnn_cell
-from tensorflow.contrib.framework.python.framework import deprecated
from tensorflow.contrib.layers.python.layers import feature_column_ops
from tensorflow.contrib.layers.python.layers import optimizers
from tensorflow.contrib.learn.python.learn.estimators import constants
@@ -455,6 +455,7 @@ class LegacyConstructorTest(test.TestCase):
      return {'inputs': inputs}, labels
    return input_fn


# TODO(jtbates): move all tests below to a benchmark test.
class StateSavingRNNEstimatorLearningTest(test.TestCase):
  """Learning tests for state saving RNN Estimators."""
@@ -22,6 +22,7 @@ import os
import tempfile
import time

+from tensorflow.contrib.learn.python.learn import estimator as estimator_lib
from tensorflow.contrib.learn.python.learn import evaluable
from tensorflow.contrib.learn.python.learn import experiment
from tensorflow.contrib.learn.python.learn import run_config

@@ -38,6 +39,7 @@ from tensorflow.python.training import saver
from tensorflow.python.training import server_lib
from tensorflow.python.training import session_run_hook
from tensorflow.python.util import compat
+from tensorflow.python.util import tf_inspect


class SheepCounter(object):
@@ -119,6 +121,15 @@ class TestBaseEstimator(object):
        compat.as_bytes(export_dir_base), compat.as_bytes('bogus_timestamp'))


+def _check_method_supports_args(method, kwargs):
+  """Checks that the given method supports the given args."""
+  supported_args = tuple(tf_inspect.getargspec(method).args)
+  for kwarg in kwargs:
+    if kwarg not in supported_args:
+      raise ValueError(
+          'Argument `{}` is not supported in method {}.'.format(kwarg, method))
+
+
class TestEstimator(
    TestBaseEstimator, evaluable.Evaluable, trainable.Trainable):
|
|||||||
super(TestEstimator, self).__init__(config, max_evals, eval_dict)
|
super(TestEstimator, self).__init__(config, max_evals, eval_dict)
|
||||||
tf_logging.info('Create Estimator')
|
tf_logging.info('Create Estimator')
|
||||||
|
|
||||||
|
def evaluate(self, **kwargs):
|
||||||
|
_check_method_supports_args(evaluable.Evaluable.evaluate, kwargs)
|
||||||
|
return super(TestEstimator, self).evaluate(**kwargs)
|
||||||
|
|
||||||
def fit(self, **kwargs):
|
def fit(self, **kwargs):
|
||||||
if 'hooks' in kwargs:
|
_check_method_supports_args(trainable.Trainable.fit, kwargs)
|
||||||
raise ValueError('`hooks` is defined in core Estimator')
|
|
||||||
if 'monitors' in kwargs:
|
if 'monitors' in kwargs:
|
||||||
self.monitors = kwargs['monitors']
|
self.monitors = kwargs['monitors']
|
||||||
return super(TestEstimator, self).train(**kwargs)
|
return super(TestEstimator, self).train(**kwargs)
|
||||||
@ -136,6 +150,13 @@ class TestEstimator(
|
|||||||
def train(self, **kwargs):
|
def train(self, **kwargs):
|
||||||
raise ValueError('`train` is not defined in Estimator.')
|
raise ValueError('`train` is not defined in Estimator.')
|
||||||
|
|
||||||
|
def export_savedmodel(
|
||||||
|
self, export_dir_base, serving_input_fn, **kwargs):
|
||||||
|
_check_method_supports_args(
|
||||||
|
estimator_lib.Estimator.export_savedmodel, kwargs)
|
||||||
|
return super(TestEstimator, self).export_savedmodel(
|
||||||
|
export_dir_base, serving_input_fn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
|
class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
|
||||||
|
|
||||||
@ -144,17 +165,22 @@ class TestCoreEstimator(TestBaseEstimator, core_estimator.Estimator):
|
|||||||
tf_logging.info('Create Core Estimator')
|
tf_logging.info('Create Core Estimator')
|
||||||
|
|
||||||
def evaluate(self, **kwargs):
|
def evaluate(self, **kwargs):
|
||||||
if 'eval_metrics' in kwargs:
|
_check_method_supports_args(core_estimator.Estimator.evaluate, kwargs)
|
||||||
raise ValueError('`eval_metrics` is not defined in core Estimator')
|
|
||||||
return super(TestCoreEstimator, self).evaluate(**kwargs)
|
return super(TestCoreEstimator, self).evaluate(**kwargs)
|
||||||
|
|
||||||
def train(self, **kwargs):
|
def train(self, **kwargs):
|
||||||
if 'monitors' in kwargs:
|
_check_method_supports_args(core_estimator.Estimator.train, kwargs)
|
||||||
raise ValueError('`monitors` is not defined in core Estimator')
|
|
||||||
if 'hooks' in kwargs:
|
if 'hooks' in kwargs:
|
||||||
self.monitors = kwargs['hooks']
|
self.monitors = kwargs['hooks']
|
||||||
return super(TestCoreEstimator, self).train(**kwargs)
|
return super(TestCoreEstimator, self).train(**kwargs)
|
||||||
|
|
||||||
|
def export_savedmodel(
|
||||||
|
self, export_dir_base, serving_input_receiver_fn, **kwargs):
|
||||||
|
_check_method_supports_args(
|
||||||
|
core_estimator.Estimator.export_savedmodel, kwargs)
|
||||||
|
return super(TestCoreEstimator, self).export_savedmodel(
|
||||||
|
export_dir_base, serving_input_receiver_fn, **kwargs)
|
||||||
|
|
||||||
|
|
||||||
class _NoopHook(session_run_hook.SessionRunHook):
|
class _NoopHook(session_run_hook.SessionRunHook):
|
||||||
pass
|
pass
|
||||||
@@ -184,6 +210,23 @@ class ExperimentTest(test.TestCase):
        eval_input_fn='eval_input',
        eval_metrics='eval_metrics')

+  def test_default_output_alternative_key_core_estimator(self):
+    est = TestCoreEstimator()
+    export_strategy = saved_model_export_utils.make_export_strategy(
+        est,
+        default_output_alternative_key='export_key',
+        exports_to_keep=None)
+    ex = experiment.Experiment(
+        est,
+        train_input_fn='train_input',
+        eval_input_fn='eval_input',
+        train_steps=100,
+        eval_steps=100,
+        export_strategies=export_strategy)
+    with self.assertRaisesRegexp(
+        ValueError, 'default_output_alternative_key is not supported'):
+      ex.train_and_evaluate()
+
  def test_train(self):
    for est in self._estimators_for_tests():
      eval_metrics = 'eval_metrics' if not isinstance(
@@ -508,7 +551,9 @@ class ExperimentTest(test.TestCase):
      eval_metrics = 'eval_metrics' if not isinstance(
          est, core_estimator.Estimator) else None
      export_strategy_1 = saved_model_export_utils.make_export_strategy(
-          est, 'export_input_1', exports_to_keep=None)
+          est,
+          None if isinstance(est, core_estimator.Estimator) else 'export_1',
+          exports_to_keep=None)

      ex = experiment.Experiment(
          est,

@@ -531,9 +576,13 @@ class ExperimentTest(test.TestCase):
      # After reset with list, the count should increase with the number of
      # items.
      export_strategy_2 = saved_model_export_utils.make_export_strategy(
-          est, 'export_input_2', exports_to_keep=None)
+          est,
+          None if isinstance(est, core_estimator.Estimator) else 'export_2',
+          exports_to_keep=None)
      export_strategy_3 = saved_model_export_utils.make_export_strategy(
-          est, 'export_input_3', exports_to_keep=None)
+          est,
+          None if isinstance(est, core_estimator.Estimator) else 'export_3',
+          exports_to_keep=None)

      old_es = ex.reset_export_strategies(
          [export_strategy_2, export_strategy_3])
|
|||||||
est, core_estimator.Estimator) else None
|
est, core_estimator.Estimator) else None
|
||||||
noop_hook = _NoopHook()
|
noop_hook = _NoopHook()
|
||||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||||
est, 'export_input', exports_to_keep=None)
|
est,
|
||||||
|
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||||
|
exports_to_keep=None)
|
||||||
ex = experiment.Experiment(
|
ex = experiment.Experiment(
|
||||||
est,
|
est,
|
||||||
train_input_fn='train_input',
|
train_input_fn='train_input',
|
||||||
@ -625,7 +676,9 @@ class ExperimentTest(test.TestCase):
|
|||||||
est, core_estimator.Estimator) else None
|
est, core_estimator.Estimator) else None
|
||||||
noop_hook = _NoopHook()
|
noop_hook = _NoopHook()
|
||||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||||
est, 'export_input', exports_to_keep=None)
|
est,
|
||||||
|
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||||
|
exports_to_keep=None)
|
||||||
ex = experiment.Experiment(
|
ex = experiment.Experiment(
|
||||||
est,
|
est,
|
||||||
train_input_fn='train_input',
|
train_input_fn='train_input',
|
||||||
@ -646,7 +699,9 @@ class ExperimentTest(test.TestCase):
|
|||||||
eval_metrics = 'eval_metrics' if not isinstance(
|
eval_metrics = 'eval_metrics' if not isinstance(
|
||||||
est, core_estimator.Estimator) else None
|
est, core_estimator.Estimator) else None
|
||||||
export_strategy = saved_model_export_utils.make_export_strategy(
|
export_strategy = saved_model_export_utils.make_export_strategy(
|
||||||
est, 'export_input', exports_to_keep=None)
|
est,
|
||||||
|
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||||
|
exports_to_keep=None)
|
||||||
ex = experiment.Experiment(
|
ex = experiment.Experiment(
|
||||||
est,
|
est,
|
||||||
train_input_fn='train_input',
|
train_input_fn='train_input',
|
||||||
@ -796,7 +851,9 @@ class ExperimentTest(test.TestCase):
|
|||||||
def test_test(self):
|
def test_test(self):
|
||||||
for est in self._estimators_for_tests():
|
for est in self._estimators_for_tests():
|
||||||
exp_strategy = saved_model_export_utils.make_export_strategy(
|
exp_strategy = saved_model_export_utils.make_export_strategy(
|
||||||
est, 'export_input', exports_to_keep=None)
|
est,
|
||||||
|
None if isinstance(est, core_estimator.Estimator) else 'export_input',
|
||||||
|
exports_to_keep=None)
|
||||||
ex = experiment.Experiment(
|
ex = experiment.Experiment(
|
||||||
est,
|
est,
|
||||||
train_input_fn='train_input',
|
train_input_fn='train_input',
|
||||||
|
@@ -42,6 +42,7 @@ from tensorflow.contrib.learn.python.learn.estimators import constants
from tensorflow.contrib.learn.python.learn.estimators import prediction_key
from tensorflow.contrib.learn.python.learn.utils import gc
from tensorflow.contrib.learn.python.learn.utils import input_fn_utils
+from tensorflow.python.estimator import estimator as core_estimator
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors_impl
from tensorflow.python.platform import gfile

@@ -352,7 +353,8 @@ def make_export_strategy(serving_input_fn,
      `InputFnOps`.
    default_output_alternative_key: the name of the head to serve when an
      incoming serving request does not explicitly request a specific head.
-      Not needed for single-headed models.
+      Must be `None` if the estimator inherits from ${tf.estimator.Estimator}
+      or for single-headed models.
    assets_extra: A dict specifying how to populate the assets.extra directory
      within the exported SavedModel. Each key should give the destination
      path (including the filename) relative to the assets.extra directory.
@@ -384,14 +386,30 @@ def make_export_strategy(serving_input_fn,

    Returns:
      The string path to the exported directory.

+    Raises:
+      ValueError: If `estimator` is a ${tf.estimator.Estimator} instance
+        and `default_output_alternative_key` was specified.
    """
-    export_result = estimator.export_savedmodel(
-        export_dir_base,
-        serving_input_fn,
-        default_output_alternative_key=default_output_alternative_key,
-        assets_extra=assets_extra,
-        as_text=as_text,
-        checkpoint_path=checkpoint_path)
+    if isinstance(estimator, core_estimator.Estimator):
+      if default_output_alternative_key is not None:
+        raise ValueError(
+            'default_output_alternative_key is not supported in core '
+            'Estimator. Given: {}'.format(default_output_alternative_key))
+      export_result = estimator.export_savedmodel(
+          export_dir_base,
+          serving_input_fn,
+          assets_extra=assets_extra,
+          as_text=as_text,
+          checkpoint_path=checkpoint_path)
+    else:
+      export_result = estimator.export_savedmodel(
+          export_dir_base,
+          serving_input_fn,
+          default_output_alternative_key=default_output_alternative_key,
+          assets_extra=assets_extra,
+          as_text=as_text,
+          checkpoint_path=checkpoint_path)

    garbage_collect_exports(export_dir_base, exports_to_keep)
    return export_result
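The branch above amounts to a type-based dispatch: reject `default_output_alternative_key` for core estimators and forward it otherwise. A minimal sketch of that pattern with stand-in classes (these are not the real Estimator APIs, only placeholders for illustration):

```python
class CoreEstimator(object):
  """Stand-in for tf.estimator.Estimator (no output-alternative support)."""

  def export_savedmodel(self, export_dir_base, serving_input_fn, **kwargs):
    return export_dir_base + '/core'


class ContribEstimator(object):
  """Stand-in for the contrib Estimator, which accepts the extra argument."""

  def export_savedmodel(self, export_dir_base, serving_input_fn,
                        default_output_alternative_key=None, **kwargs):
    return export_dir_base + '/contrib'


def export(estimator, export_dir_base, serving_input_fn,
           default_output_alternative_key=None):
  """Routes on the estimator type, mirroring the branch added above."""
  if isinstance(estimator, CoreEstimator):
    if default_output_alternative_key is not None:
      raise ValueError(
          'default_output_alternative_key is not supported in core '
          'Estimator. Given: {}'.format(default_output_alternative_key))
    return estimator.export_savedmodel(export_dir_base, serving_input_fn)
  return estimator.export_savedmodel(
      export_dir_base, serving_input_fn,
      default_output_alternative_key=default_output_alternative_key)


print(export(ContribEstimator(), '/tmp/exports', serving_input_fn=None,
             default_output_alternative_key='head-1'))  # -> /tmp/exports/contrib
try:
  export(CoreEstimator(), '/tmp/exports', serving_input_fn=None,
         default_output_alternative_key='head-1')
except ValueError as e:
  print(e)
```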
@@ -1,9 +1,9 @@
+package(default_visibility = ["//tensorflow:__subpackages__"])

licenses(["notice"])  # Apache 2.0

exports_files(["LICENSE"])

-package(default_visibility = ["//tensorflow:__subpackages__"])

load("//tensorflow:tensorflow.bzl", "cuda_py_tests")

py_library(

@@ -12,8 +12,8 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""##Signal ops.

-"""
@@frames
"""

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Signal ops."""

from __future__ import absolute_import
from __future__ import division
@@ -33,34 +33,34 @@ class FramesTest(test.TestCase):
    with self.test_session():
      tensor = constant_op.constant(np.arange(9152), dtypes.int32)
      tensor = array_ops.expand_dims(tensor, 0)

      result = shape_ops.frames(tensor, 512, 180)
      result = result.eval()

      expected = np.tile(np.arange(512), (49, 1))
      expected += np.tile(np.arange(49) * 180, (512, 1)).T

      expected = np.expand_dims(expected, axis=0)
      expected = np.array(expected, dtype=np.int32)

      self.assertAllEqual(expected, result)

  def test_mapping_of_indices_with_padding(self):
    with self.test_session():
      tensor = constant_op.constant(np.arange(10000), dtypes.int32)
      tensor = array_ops.expand_dims(tensor, 0)

      result = shape_ops.frames(tensor, 512, 192)
      result = result.eval()

      expected = np.tile(np.arange(512), (51, 1))
      expected += np.tile(np.arange(51) * 192, (512, 1)).T

      expected[expected >= 10000] = 0

      expected = np.expand_dims(expected, axis=0)
      expected = np.array(expected, dtype=np.int32)

      self.assertAllEqual(expected, result)
@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""Signal ops."""

from __future__ import absolute_import
from __future__ import division

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
+"""General shape ops for frames."""

from __future__ import absolute_import
from __future__ import division
@@ -23,59 +24,64 @@ from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops


def frames(signal, frame_length, frame_step, name=None):
  """Frame a signal into overlapping frames.

  May be used in front of spectral functions.

  For example:

  ```python
  pcm = tf.placeholder(tf.float32, [None, 9152])
  frames = tf.contrib.signal.frames(pcm, 512, 180)
  magspec = tf.abs(tf.spectral.rfft(frames, [512]))
  image = tf.expand_dims(magspec, 3)
  ```

  Args:
    signal: A `Tensor` of shape `[batch_size, signal_length]`.
    frame_length: An `int32` or `int64` `Tensor`. The length of each frame.
    frame_step: An `int32` or `int64` `Tensor`. The step between frames.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` of frames with shape `[batch_size, num_frames, frame_length]`.

+  Raises:
+    ValueError: if signal does not have rank 2.
  """
  with ops.name_scope(name, "frames", [signal, frame_length, frame_step]):
    signal = ops.convert_to_tensor(signal, name="signal")
    frame_length = ops.convert_to_tensor(frame_length, name="frame_length")
    frame_step = ops.convert_to_tensor(frame_step, name="frame_step")

    signal_rank = signal.shape.ndims

    if signal_rank != 2:
      raise ValueError("expected signal to have rank 2 but was " + signal_rank)

    signal_length = array_ops.shape(signal)[1]

    num_frames = math_ops.ceil((signal_length - frame_length) / frame_step)
    num_frames = 1 + math_ops.cast(num_frames, dtypes.int32)

    pad_length = (num_frames - 1) * frame_step + frame_length
-    pad_signal = array_ops.pad(
-        signal, [[0, 0], [0, pad_length - signal_length]])
+    pad_signal = array_ops.pad(signal, [[0, 0], [0,
+                                                 pad_length - signal_length]])

    indices_frame = array_ops.expand_dims(math_ops.range(frame_length), 0)
    indices_frames = array_ops.tile(indices_frame, [num_frames, 1])

    indices_step = array_ops.expand_dims(
        math_ops.range(num_frames) * frame_step, 1)
    indices_steps = array_ops.tile(indices_step, [1, frame_length])

    indices = indices_frames + indices_steps

-    # TODO(Androbin): remove `transpose` when `gather` gets `axis` support
+    # TODO(androbin): remove `transpose` when `gather` gets `axis` support
    pad_signal = array_ops.transpose(pad_signal)
-    frames = array_ops.gather(pad_signal, indices)
-    frames = array_ops.transpose(frames, perm=[2, 0, 1])
+    signal_frames = array_ops.gather(pad_signal, indices)
+    signal_frames = array_ops.transpose(signal_frames, perm=[2, 0, 1])

-    return frames
+    return signal_frames
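The gather-based framing above can be mirrored with plain NumPy index arithmetic. The sketch below reproduces the expected values constructed in the FramesTest cases earlier in this change; the helper name is made up, but the shapes and values follow directly from the same formula.

```python
import numpy as np


def frames_np(signal, frame_length, frame_step):
  """NumPy mirror of the framing index math used by the op above."""
  signal = np.asarray(signal)
  signal_length = signal.shape[1]
  # 1 + ceil((signal_length - frame_length) / frame_step) frames.
  num_frames = 1 + int(
      np.ceil((signal_length - frame_length) / float(frame_step)))
  pad_length = (num_frames - 1) * frame_step + frame_length
  padded = np.pad(signal, [[0, 0], [0, pad_length - signal_length]], 'constant')
  # indices[i, j] = i * frame_step + j, the same grid the op gathers with.
  indices = (np.arange(frame_length)[None, :] +
             np.arange(num_frames)[:, None] * frame_step)
  return padded[:, indices]  # shape [batch_size, num_frames, frame_length]


pcm = np.arange(9152, dtype=np.int32)[None, :]
out = frames_np(pcm, 512, 180)
print(out.shape)  # (1, 49, 512), matching the first FramesTest case.
```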
@@ -127,6 +127,6 @@ class FakeSummaryWriter(object):

  def reopen(self):
    pass

  def close(self):
    pass
@@ -97,6 +97,29 @@ py_test(
    ],
)

+py_library(
+    name = "pprof_profiler",
+    srcs = ["pprof_profiler.py"],
+    srcs_version = "PY2AND3",
+    deps = ["@pprof_profile_proto//:pprof_proto_py"],
+)
+
+py_test(
+    name = "pprof_profiler_test",
+    srcs = ["pprof_profiler_test.py"],
+    main = "pprof_profiler_test.py",
+    srcs_version = "PY2AND3",
+    tags = ["no_pip"],  # TODO(annarev): get it working with pip.
+    deps = [
+        ":pprof_profiler",
+        "//tensorflow/python:client",
+        "//tensorflow/python:client_testlib",
+        "//tensorflow/python:framework_test_lib",
+        "//tensorflow/python:platform_test",
+        "@pprof_profile_proto//:pprof_proto_py",
+    ],
+)
+
# -----------------------------------------------------------------------------
# Google-internal targets. These must be at the end for syncrepo.
tensorflow/contrib/tfprof/python/tools/tfprof/pprof_profiler.py (new file, 445 lines)
@@ -0,0 +1,445 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Profiler for TensorFlow models that outputs data in pprof format.

See https://github.com/google/pprof/blob/master/proto/profile.proto for pprof
profile format.
The following needs to be set for profiler to work:
  * trace_level needs to be set to FULL_TRACE
  * run_metadata object should be passed in to session.run call

Sample usage:
  options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
  run_metadata = tf.RunMetadata()

  with tf.Session() as sess:
    ...
    sess.run(computation, run_metadata=run_metadata, options=options)
    pprof_profiler.profile(sess.graph, run_metadata, output_dir)

The code above would output a pprof profile to separate output_dir/.*.pb.gz
file for each device. These files can be passed to pprof for formatting.
For example:
  pprof -png --nodecount=100 --sample_index=1 output_dir/profile_output.pb.gz
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from collections import defaultdict
from collections import namedtuple
import gzip
import os
import string
import sys
import time

from pprof import profile_pb2


if sys.version_info < (3,):
  maketrans = string.maketrans
else:
  maketrans = str.maketrans


ProfileDatum = namedtuple('ProfileDatum', [
    'node_exec_stats', 'op_type', 'traceback'])


class StringTable(object):
  """Keeps track of strings to add to string_table in pprof proto."""

  def __init__(self):
    # Pprof requires first entry in string_table to be ''.
    self._string_table = ['']
    self._string_to_index = {'': 0}

  def index_of(self, value_str):
    """Get index of value_str in the string table.

    If value_str is not in the string table, we will add it at the end
    and then return the new index.
    Args:
      value_str: (string) Value to lookup/add in/to the string table.

    Returns:
      Index of value_str in the string table.
    """
    if value_str is None:
      value_str = ''
    if value_str in self._string_to_index:
      return self._string_to_index[value_str]
    index = len(self._string_table)
    self._string_table.append(value_str)
    self._string_to_index[value_str] = index
    return index

  def next_index(self):
    """Gets index that would be assigned to the next added string.

    Returns:
      Index of the next string if it was added.
    """
    return len(self._string_table)

  def string_table(self):
    """Returns a list of strings to store in pprof's string_table."""
    return self._string_table


class Functions(object):
  """Keeps track of `Function` protos for pprof profile."""

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # Maps tuples in the form (file_path, function_name, start_line_number)
    # to `Function` protos.
    self._function_key_to_function = {}

  def index_of(self, file_path, function_name, function_start_line):
    """Returns index of the function, adding the function if needed.

    Args:
      file_path: (string) Path to file where the function is defined.
      function_name: (string) Function name.
      function_start_line: (integer) Start line number of function definition.

    Returns:
      Function index.
    """
    function_key = (file_path, function_name, function_start_line)
    if function_key in self._function_key_to_function:
      return self._function_key_to_function[function_key].id
    else:
      # Function indexes should start from 1
      function_index = len(self._function_key_to_function) + 1
      function = profile_pb2.Function()
      function.id = function_index
      function.name = self._string_table.index_of(function_name)
      function.filename = self._string_table.index_of(file_path)
      function.start_line = function_start_line
      self._function_key_to_function[function_key] = function
      return function_index

  def function_protos(self):
    """Returns list of `profile_pb2.Function` protos."""
    return self._function_key_to_function.values()


class Locations(object):
  """Keeps track of `Location` protos for pprof profile.

  `Locations` store information about function call locations.
  """

  def __init__(self, functions):
    """Constructor.

    Args:
      functions: A `Functions` object.
    """
    self._functions = functions
    # Maps tuples in the form (file_path, called_function_name, line_number)
    # to `Location` protos.
    self._location_key_to_location = {}

  def index_of(
      self, file_path, line_number, called_function_name, called_file_path,
      called_function_start_line):
    """Returns index of the location, adding the location if needed.

    Args:
      file_path: (string) Path to file that makes the call.
      line_number: (integer) Call line number.
      called_function_name: (string) Function name of the function called at
        `file_path` and `line_number`.
      called_file_path: (string) Path to file where the called function is
        defined.
      called_function_start_line: (integer) Start line number of called
        function definition in `called_file_path` file.

    Returns:
      Index of location.
    """
    location_key = (file_path, called_function_name, line_number)
    if location_key in self._location_key_to_location:
      location = self._location_key_to_location[location_key]
      return location.id
    else:
      # Location indexes should start from 1
      location_index = len(self._location_key_to_location) + 1
      location = profile_pb2.Location()
      location.id = location_index
      self._location_key_to_location[location_key] = location

      line = location.line.add()
      line.function_id = self._functions.index_of(
          called_file_path, called_function_name, called_function_start_line)
      line.line = line_number
      return location_index

  def location_protos(self):
    """Returns list of `profile_pb2.Location` protos."""
    return self._location_key_to_location.values()


class Samples(object):
  """Keeps track of `Sample` protos for pprof profile.

  Samples store the following statistics in order:
  count, all_time, op_time
  """

  def __init__(self, string_table):
    """Constructor.

    Args:
      string_table: A `StringTable` object.
    """
    self._string_table = string_table
    # TODO(annarev): figure out if location is unique for each node name.
    # If not, also key this dictionary based on location ids.
    self._node_name_to_sample = {}

  def add(self, datum, location_ids):
    """Adds a sample data point.

    Args:
      datum: `ProfileDatum` to add a sample for.
      location_ids: List of numeric location ids for this
        sample.
    """
    node_name = datum.node_exec_stats.node_name
    if node_name in self._node_name_to_sample:
      sample = self._node_name_to_sample[node_name]
      sample.location_id.extend(location_ids)
    else:
      sample = profile_pb2.Sample()
      # Sample stores 3 values: count, all_time, op_time
      sample.value.extend([0, 0, 0])

      label = sample.label.add()
      label.key = self._string_table.index_of('node_name')
      label.str = self._string_table.index_of(node_name)
      label = sample.label.add()
      label.key = self._string_table.index_of('op_type')
      label.str = self._string_table.index_of(datum.op_type)
      self._node_name_to_sample[node_name] = sample
    sample.value[0] += 1
    sample.value[1] += datum.node_exec_stats.all_end_rel_micros
    sample.value[2] += (
        datum.node_exec_stats.op_end_rel_micros -
        datum.node_exec_stats.op_start_rel_micros)

  def get_sample_protos(self):
    """Returns list of `Sample` protos for pprof profile."""
    return self._node_name_to_sample.values()


class PprofProfiler(object):
  """Creates profiles in pprof format."""

  def __init__(self, graph, run_metadata):
    """Constructor.

    Args:
      graph: A `Graph` instance.
      run_metadata: A `RunMetadata` object.
    """
    self._graph = graph
    self._run_metadata = run_metadata
    self._string_table = StringTable()
    self._functions = Functions(self._string_table)
    self._locations = Locations(self._functions)

  def profile(self):
    """Generates pprof profiles.

    Returns:
      Dictionary mapping from device name to proto in `profile_pb2.Profile`
      format.
    """
    profiles = {}
    data_generator_func = self._get_profile_data_generator()
    for device_index, device_stats in enumerate(
        self._run_metadata.step_stats.dev_stats):
      # Create profile
      pprof_proto = self._get_pprof_proto(data_generator_func(device_stats))
      if not pprof_proto.sample:
        print(
            'Not enough data to create profile for device %s. Did you pass '
            'RunMetadata to session.run call?' % device_stats.device)
        continue
      # Add device name comment
      device_count = len(self._run_metadata.step_stats.dev_stats)
      device_description = (
          'Device %d of %d: %s' %
          (device_index + 1, device_count, device_stats.device))
      device_description_str_index = self._string_table.next_index()
      pprof_proto.string_table.append(device_description)
      pprof_proto.comment.append(device_description_str_index)
      profiles[device_stats.device] = pprof_proto
    return profiles

  def _get_pprof_proto(self, profile_datum_generator):
    """Returns profile data in pprof proto format.

    Args:
      profile_datum_generator: Generator outputting `ProfileDatum` objects.

    Returns:
      A proto in pprof format.
    """
    pprof_profile = profile_pb2.Profile()
    samples = Samples(self._string_table)

    for datum in profile_datum_generator:
      if not datum.traceback:
        continue

      stack_frame = datum.traceback[-1]
      after_apply_op = False
      location_ids = []

      # We add locations from stack trace in bottom-up order.
      for stack_frame_index in reversed(range(len(datum.traceback) - 1)):
        prev_stack_frame = stack_frame
        stack_frame = datum.traceback[stack_frame_index]

        # Call at current frame calls function at previous frame.
        prev_file_path = prev_stack_frame[0]
        prev_function = prev_stack_frame[2]
        prev_function_start_line = prev_stack_frame[4]
        curr_file_path = stack_frame[0]
        curr_line_number = stack_frame[1]

        # Skip all calls up to apply_op since they are the same for all ops.
        if not after_apply_op:
          if prev_function == 'apply_op':
            after_apply_op = True
          continue
        location_index = self._locations.index_of(
            curr_file_path, curr_line_number,
            prev_function, prev_file_path, prev_function_start_line)
        location_ids.append(location_index)
      samples.add(datum, location_ids)

    sample_type_description = 'count'
    sample_type = pprof_profile.sample_type.add()
    sample_type.type = self._string_table.index_of(sample_type_description)
    sample_type.unit = self._string_table.index_of('count')
    sample_type_description = 'all_time'
    sample_type = pprof_profile.sample_type.add()
    sample_type.type = self._string_table.index_of(sample_type_description)
    sample_type.unit = self._string_table.index_of('nanoseconds')
    sample_type_description = 'op_time'
    sample_type = pprof_profile.sample_type.add()
    sample_type.type = self._string_table.index_of(sample_type_description)
    sample_type.unit = self._string_table.index_of('nanoseconds')

    pprof_profile.string_table.extend(self._string_table.string_table())
    pprof_profile.sample.extend(samples.get_sample_protos())
    pprof_profile.function.extend(self._functions.function_protos())
    pprof_profile.location.extend(self._locations.location_protos())
    return pprof_profile

  def _get_profile_data_generator(self):
    """Get function that generates `ProfileDatum` objects.

    Returns:
      A function that generates `ProfileDatum` objects.
    """
    node_to_traceback = defaultdict(list)
    node_to_op_type = defaultdict(str)
    for op in self._graph.get_operations():
      node_to_traceback[op.name] = op.traceback_with_start_lines
      node_to_op_type[op.name] = op.type

    def profile_data_generator(device_step_stats):
      for node_stats in device_step_stats.node_stats:
        if node_stats.node_name == '_SOURCE' or node_stats.node_name == '_SINK':
          continue
        yield ProfileDatum(
            node_stats,
            node_to_op_type[node_stats.node_name],
            node_to_traceback[node_stats.node_name])

    return profile_data_generator


def get_profiles(graph, run_metadata):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.

  Returns:
    A dictionary mapping from device name to pprof proto for that device.
  """
  return PprofProfiler(graph, run_metadata).profile()


def profile(graph, run_metadata, output_dir=None):
  """Generate profiles in pprof format.

  See https://github.com/google/pprof/blob/master/proto/profile.proto
  for pprof proto format.

  Args:
    graph: A `Graph` object.
    run_metadata: A `RunMetadata` proto.
    output_dir: (string) Directory to output pprof profile to.
      Profile files for each device will be stored in compressed
      serialized proto format. If output_dir is None, profile protos
      will be printed to stdout instead.

  Returns:
    List of output files created by this profile call.
    (Note: this list will be empty if output_dir is None)
  """
  profiles = get_profiles(graph, run_metadata)
  output_file_template = None
  if output_dir:
    if not os.path.isdir(output_dir):
      os.makedirs(output_dir)
    time_suffix = time.strftime('%Y%m%d%H%M%S')
    output_file_template = os.path.join(
        output_dir, '%s_' + time_suffix + '.pb.gz')

  profile_files = []
  for device, pprof_proto in profiles.items():
    if output_file_template is None:
      print('No output directory specified, printing to stdout instead.')
      print(pprof_proto)
    else:
      device_name = str(device).strip('/').translate(
          maketrans('/:', '__'))
      profile_file = output_file_template % device_name
      profile_files.append(profile_file)
      with gzip.open(profile_file, 'w') as output_file:
        print('Writing profile to %s...' % profile_file)
        output_file.write(pprof_proto.SerializeToString())
  return profile_files
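A minimal end-to-end sketch of how the new module is meant to be driven, following its own docstring. It assumes the TF 1.x session API of this era and that the `@pprof_profile_proto` dependency from the BUILD change above is available; the toy graph and the `/tmp/pprof_out` path are arbitrary examples.

```python
import tensorflow as tf

from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler

# A toy graph, purely for illustration.
a = tf.random_normal([64, 64])
b = tf.random_normal([64, 64])
c = tf.matmul(a, b)

# trace_level must be FULL_TRACE and run_metadata must be passed to run()
# for the profiler to see per-node execution stats.
options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE)
run_metadata = tf.RunMetadata()

with tf.Session() as sess:
  sess.run(c, options=options, run_metadata=run_metadata)
  # Writes one gzipped profile proto per device under /tmp/pprof_out, which
  # can then be fed to the pprof tool, e.g.:
  #   pprof -png --nodecount=100 --sample_index=1 /tmp/pprof_out/<file>.pb.gz
  profile_files = pprof_profiler.profile(
      sess.graph, run_metadata, output_dir='/tmp/pprof_out')
  print(profile_files)
```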
@@ -0,0 +1,164 @@
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the 'License');
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an 'AS IS' BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for pprof_profiler."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import gzip

from pprof import profile_pb2
from tensorflow.contrib.tfprof.python.tools.tfprof import pprof_profiler
from tensorflow.core.framework import step_stats_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.python.framework import constant_op
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import test


class PprofProfilerTest(test.TestCase):

  def testDataEmpty(self):
    output_dir = test.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    graph.get_operations.return_value = []

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEquals(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEquals(0, len(profile_files))

  def testRunMetadataEmpty(self):
    output_dir = test.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()
    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    op1.traceback = [('a/b/file1', 10, 'some_var')]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEquals(0, len(profiles))
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEquals(0, len(profile_files))

  def testValidProfile(self):
    output_dir = test.get_temp_dir()
    run_metadata = config_pb2.RunMetadata()

    node1 = step_stats_pb2.NodeExecStats(
        node_name='Add/123',
        op_start_rel_micros=3,
        op_end_rel_micros=5,
        all_end_rel_micros=4)

    run_metadata = config_pb2.RunMetadata()
    device1 = run_metadata.step_stats.dev_stats.add()
    device1.device = 'deviceA'
    device1.node_stats.extend([node1])

    graph = test.mock.MagicMock()
    op1 = test.mock.MagicMock()
    op1.name = 'Add/123'
    op1.traceback = [
        ('a/b/file1', 10, 'apply_op', 'abc'), ('a/c/file2', 12, 'my_op', 'def')]
    op1.type = 'add'
    graph.get_operations.return_value = [op1]

    expected_proto = """sample_type {
  type: 5
  unit: 5
}
sample_type {
  type: 6
  unit: 7
}
sample_type {
  type: 8
  unit: 7
}
sample {
  value: 1
  value: 4
  value: 2
  label {
    key: 1
    str: 2
  }
  label {
    key: 3
    str: 4
  }
}
string_table: ""
string_table: "node_name"
string_table: "Add/123"
string_table: "op_type"
string_table: "add"
string_table: "count"
string_table: "all_time"
string_table: "nanoseconds"
string_table: "op_time"
string_table: "Device 1 of 1: deviceA"
comment: 9
"""
    # Test with protos
    profiles = pprof_profiler.get_profiles(graph, run_metadata)
    self.assertEquals(1, len(profiles))
    self.assertTrue('deviceA' in profiles)
    self.assertEquals(expected_proto, str(profiles['deviceA']))
    # Test with files
    profile_files = pprof_profiler.profile(
        graph, run_metadata, output_dir)
    self.assertEquals(1, len(profile_files))
    with gzip.open(profile_files[0]) as profile_file:
      profile_contents = profile_file.read()
      profile = profile_pb2.Profile()
      profile.ParseFromString(profile_contents)
      self.assertEquals(expected_proto, str(profile))

  def testProfileWithWhileLoop(self):
    options = config_pb2.RunOptions()
    options.trace_level = config_pb2.RunOptions.FULL_TRACE
    run_metadata = config_pb2.RunMetadata()

    num_iters = 5
    with self.test_session() as sess:
      i = constant_op.constant(0)
      c = lambda i: math_ops.less(i, num_iters)
      b = lambda i: math_ops.add(i, 1)
      r = control_flow_ops.while_loop(c, b, [i])
      sess.run(r, options=options, run_metadata=run_metadata)
      profiles = pprof_profiler.get_profiles(sess.graph, run_metadata)
      self.assertEquals(1, len(profiles))
      profile = next(iter(profiles.values()))
      add_samples = []  # Samples for the while/Add node
      for sample in profile.sample:
        if profile.string_table[sample.label[0].str] == 'while/Add':
          add_samples.append(sample)
      # Values for same nodes are aggregated.
      self.assertEquals(1, len(add_samples))
      # Value of "count" should be equal to number of iterations.
      self.assertEquals(num_iters, add_samples[0].value[0])


if __name__ == '__main__':
  test.main()
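To read `expected_proto` in the test above: pprof stores every string once in `string_table`, and the numeric `type`, `unit`, `key`, and `str` fields are indices into it. A tiny decoding sketch:

```python
# String table exactly as emitted in the test's expected_proto.
string_table = ['', 'node_name', 'Add/123', 'op_type', 'add',
                'count', 'all_time', 'nanoseconds', 'op_time',
                'Device 1 of 1: deviceA']

# sample_type {type: 5 unit: 5} -> ('count', 'count')
# sample_type {type: 6 unit: 7} -> ('all_time', 'nanoseconds')
# sample_type {type: 8 unit: 7} -> ('op_time', 'nanoseconds')
for type_idx, unit_idx in [(5, 5), (6, 7), (8, 7)]:
  print(string_table[type_idx], string_table[unit_idx])

# The single sample's values line up with the NodeExecStats in the test:
# count = 1, all_time = all_end_rel_micros = 4,
# op_time = op_end_rel_micros - op_start_rel_micros = 5 - 3 = 2.
```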
@@ -272,7 +272,8 @@ RdmaChannel::RdmaChannel(const RdmaAdapter* adapter, const string local_name,
  self_.qpn = qp_->qp_num;
  self_.psn = static_cast<uint32_t>(random::New64()) & 0xffffff;
  union ibv_gid gid;
-  CHECK(!ibv_query_gid(adapter_->context_, (uint8_t) 1, 0, &gid)) << "Query gid";
+  CHECK(!ibv_query_gid(adapter_->context_, (uint8_t)1, 0, &gid))
+      << "Query gid";
  self_.snp = gid.global.subnet_prefix;
  self_.iid = gid.global.interface_id;
}

@@ -479,7 +480,7 @@ void RdmaChannel::Connect(const RdmaAddress& remoteAddr) {
  attr.dest_qp_num = remoteAddr.qpn;
  attr.rq_psn = remoteAddr.psn;
  attr.max_dest_rd_atomic = 1;
  attr.min_rnr_timer = 12;
  attr.ah_attr.is_global = 1;
  attr.ah_attr.grh.dgid.global.subnet_prefix = remoteAddr.snp;
  attr.ah_attr.grh.dgid.global.interface_id = remoteAddr.iid;
@@ -248,8 +248,8 @@ void EncodeTensorToByteBuffer(bool is_dead, const Tensor& val,
                              tdata.size(), do_nothing);
    slices[1] = ::grpc::Slice(s1, ::grpc::Slice::STEAL_REF);

-    gpr_slice s2 = gpr_slice_new(const_cast<TensorBuffer*>(buf),
-                                 0, unref_tensorbuffer);
+    gpr_slice s2 =
+        gpr_slice_new(const_cast<TensorBuffer*>(buf), 0, unref_tensorbuffer);
    slices[2] = ::grpc::Slice(s2, ::grpc::Slice::STEAL_REF);
    num_slices += 2;
  }
@@ -135,6 +135,22 @@ cc_library(
    ],
)

+cc_library(
+    name = "virtual_placer",
+    srcs = ["virtual_placer.cc"],
+    hdrs = ["virtual_placer.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":op_performance_data_cc",
+        ":utils",
+        "//tensorflow/core:framework",
+        "//tensorflow/core:framework_lite",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:devices",
+        "//tensorflow/core/grappler/clusters:cluster",
+    ],
+)
+
cc_library(
    name = "virtual_scheduler",
    srcs = ["virtual_scheduler.cc"],

@@ -194,3 +210,24 @@ cc_test(
        "//tensorflow/core:test_main",
    ],
)
+
+cc_library(
+    name = "analytical_cost_estimator",
+    srcs = ["analytical_cost_estimator.cc"],
+    hdrs = ["analytical_cost_estimator.h"],
+    visibility = ["//visibility:public"],
+    deps = [
+        ":cost_estimator",
+        ":graph_properties",
+        ":op_level_cost_estimator",
+        ":op_performance_data_cc",
+        ":utils",
+        ":virtual_placer",
+        ":virtual_scheduler",
+        "//tensorflow/core:core_cpu",
+        "//tensorflow/core:core_cpu_internal",
+        "//tensorflow/core:lib",
+        "//tensorflow/core:protos_all_cc",
+        "//tensorflow/core/grappler:grappler_item",
+    ],
+)
tensorflow/core/grappler/costs/analytical_cost_estimator.cc (new file)
@@ -0,0 +1,128 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"

#include <limits>
#include <unordered_map>

#include "tensorflow/core/framework/attr_value.pb.h"
#include "tensorflow/core/graph/types.h"
#include "tensorflow/core/grappler/costs/graph_properties.h"
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/grappler/costs/virtual_scheduler.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/public/session.h"

namespace tensorflow {
namespace grappler {

AnalyticalCostEstimator::AnalyticalCostEstimator(Cluster* cluster,
                                                 bool use_static_shapes)
    : cluster_(cluster), use_static_shapes_(use_static_shapes) {}

Status AnalyticalCostEstimator::Initialize(const GrapplerItem& item) {
  item_ = item;
  return Status::OK();
}

Status AnalyticalCostEstimator::PredictCosts(const GraphDef& optimized_graph,
                                             CostGraphDef* cost_graph,
                                             Costs* costs) const {
  GrapplerItem item = item_;
  item.graph = optimized_graph;
  GraphProperties properties(item);
  Status status;
  if (use_static_shapes_) {
    status = properties.InferStatically();
  } else {
    status = properties.InferDynamically(cluster_);
  }

  if (!status.ok()) {
    costs->execution_time = Costs::Duration::max();
    return status;
  }

  std::unordered_map<string, CostGraphDef::Node*> name_to_cost;
  if (cost_graph) {
    for (auto& node : *cost_graph->mutable_node()) {
      name_to_cost[node.name()] = &node;
    }
  }
  std::vector<string> inaccurate_nodes;
  VirtualScheduler scheduler(optimized_graph, item_.fetch);
  VirtualPlacer placer(cluster_);
  Costs node_costs;
  do {
    const NodeDef* node = scheduler.GetCurrNode();
    std::vector<OpInfo::TensorProperties> inputs =
        properties.GetInputProperties(node->name());

    OpInfo::DeviceProperties device = placer.get_device(*node);
    OpInfo op_info;
    op_info.set_op(node->op());
    *op_info.mutable_attr() = node->attr();
    for (auto& input : inputs) {
      op_info.add_inputs()->Swap(&input);
    }
    op_info.mutable_device()->Swap(&device);

    node_costs = node_estimator_.PredictCosts(op_info);
    if (node_costs.inaccurate) {
      inaccurate_nodes.push_back(node->name());
    }
    if (cost_graph) {
      auto it = name_to_cost.find(node->name());
      CostGraphDef::Node* cost_node;
      if (it != name_to_cost.end()) {
        cost_node = it->second;
      } else {
        cost_node = cost_graph->add_node();
        cost_node->set_name(node->name());
      }
      string device_name = properties.GetDeviceName(node->name());
      cost_node->set_device(device_name);
      cost_node->set_compute_cost(
          node_costs.execution_time.asMicroSeconds().count());
      cost_node->set_compute_time(
          node_costs.compute_time.asMicroSeconds().count());
      cost_node->set_memory_time(
          node_costs.memory_time.asMicroSeconds().count());
      std::vector<OpInfo::TensorProperties> outputs =
          properties.GetOutputProperties(node->name());
      for (const auto& output : outputs) {
        auto output_info = cost_node->add_output_info();
        output_info->set_dtype(output.dtype());
        auto shape = output_info->mutable_shape();
        *shape = output.shape();
      }
    }
  } while (scheduler.MarkCurrNodeExecuted(node_costs));

  *costs = scheduler.Summary();
  VLOG(1) << inaccurate_nodes.size() << " out of "
          << optimized_graph.node_size()
          << " nodes have inaccurate time estimation";
  for (const auto& node : inaccurate_nodes) {
    VLOG(2) << "Node with inaccurate time estimation: " << node;
  }
  return Status::OK();
}

}  // end namespace grappler
}  // end namespace tensorflow
tensorflow/core/grappler/costs/analytical_cost_estimator.h (new file)
@@ -0,0 +1,63 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_

#include "tensorflow/core/grappler/costs/cost_estimator.h"
#include "tensorflow/core/grappler/costs/op_level_cost_estimator.h"
#include "tensorflow/core/grappler/grappler_item.h"
#include "tensorflow/core/lib/core/status.h"

namespace tensorflow {
class CostGraphDef;
class GraphDef;
}  // namespace tensorflow

namespace tensorflow {
namespace grappler {

class Cluster;
struct GrapplerItem;

// Estimate the cost of running a Grappler item based on the theoretical
// performance of the hardware that will run the model.
class AnalyticalCostEstimator : public CostEstimator {
 public:
  // Does not take ownership of cluster.
  explicit AnalyticalCostEstimator(Cluster* cluster, bool use_static_shapes);
  ~AnalyticalCostEstimator() override {}

  // Initializes the estimator for the specified grappler item.
  // This implementation always returns OK.
  Status Initialize(const GrapplerItem& item) override;

  // Predict the performance of each node of the optimized graph and annotate
  // the CostGraphDef with the corresponding estimates. Also returns the
  // expected latency for the whole graph.
  Status PredictCosts(const GraphDef& optimized_graph, CostGraphDef* cost_graph,
                      Costs* overall_latency) const override;

 private:
  Cluster* cluster_;  // Not owned.
  GrapplerItem item_;
  OpLevelCostEstimator node_estimator_;
  bool use_static_shapes_;
};

}  // end namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_ANALYTICAL_COST_ESTIMATOR_H_
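For orientation, here is a minimal sketch of how the new estimator could be driven, based only on the interface above. It is not code from the commit: the helper name is made up, and obtaining a grappler Cluster* and a GrapplerItem is assumed to happen elsewhere.

    #include "tensorflow/core/grappler/costs/analytical_cost_estimator.h"
    #include "tensorflow/core/lib/core/errors.h"

    // Hypothetical driver: `cluster` and `item` are assumed to come from elsewhere.
    tensorflow::Status EstimateGraphCost(tensorflow::grappler::Cluster* cluster,
                                         const tensorflow::grappler::GrapplerItem& item,
                                         tensorflow::CostGraphDef* cost_graph,
                                         tensorflow::grappler::Costs* costs) {
      // use_static_shapes=true relies on statically inferred shapes instead of
      // running the graph on the cluster.
      tensorflow::grappler::AnalyticalCostEstimator estimator(
          cluster, /*use_static_shapes=*/true);
      TF_RETURN_IF_ERROR(estimator.Initialize(item));
      // Annotates cost_graph per node and fills in the overall expected latency.
      return estimator.PredictCosts(item.graph, cost_graph, costs);
    }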
@@ -80,7 +80,8 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
     const OpInfo::DeviceProperties local_cpu = GetLocalCPUInfo();
     // Check if vector instructions are available, and refine performance
     // prediction based on this.
-    gflops = local_cpu.num_cores() * local_cpu.frequency();
+    // Frequencies are stored in MHz in the DeviceProperties.
+    gflops = local_cpu.num_cores() * local_cpu.frequency() * 1e-3;
     if (bandwidth < 0) {
       if (local_cpu.bandwidth() > 0) {
         bandwidth = local_cpu.bandwidth() / 1e6;

@@ -105,7 +106,7 @@ std::pair<double, double> OpLevelCostEstimator::GetDeviceInfo(
       // Pascal.
       cores_per_multiprocessor = 64;
     }
-    gflops = local_gpu.num_cores() * local_gpu.frequency() *
+    gflops = local_gpu.num_cores() * local_gpu.frequency() * 1e-3 *
              cores_per_multiprocessor * kOpsPerMac;
     if (bandwidth < 0) {
       CHECK(local_gpu.bandwidth() > 0);

@@ -147,7 +147,7 @@ OpInfo::DeviceProperties GetLocalCPUInfo() {
   // Combine cpu family and model into the model string.
   device.set_model(
       strings::StrCat((port::CPUFamily() << 4) + port::CPUModelNum()));
-  device.set_frequency(port::NominalCPUFrequency() * 1e-9);
+  device.set_frequency(port::NominalCPUFrequency() * 1e-6);
   device.set_num_cores(port::NumSchedulableCPUs());
   device.set_l1_cache_size(Eigen::l1CacheSize());
   device.set_l2_cache_size(Eigen::l2CacheSize());

@@ -175,7 +175,7 @@ OpInfo::DeviceProperties GetLocalGPUInfo(int gpu_id) {
   if (error == cudaSuccess) {
     device.set_vendor("NVidia");
     device.set_model(properties.name);
-    device.set_frequency(properties.clockRate / 1000);
+    device.set_frequency(properties.clockRate * 1e-3);
     device.set_num_cores(properties.multiProcessorCount);
     device.set_num_registers(properties.regsPerMultiprocessor);
     // For compute capability less than 5, l1 cache size is configurable to
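The four hunks above are all about units: NominalCPUFrequency() reports Hz and cudaDeviceProp::clockRate is in kHz, so both are now normalized to MHz when stored in DeviceProperties, and GetDeviceInfo() multiplies by 1e-3 to work in GHz (roughly giga-ops per second per core at one op per cycle). A tiny illustration of that bookkeeping, for a hypothetical 8-core 2.6 GHz CPU (not part of the commit):

    #include <cstdio>

    int main() {
      const double nominal_hz = 2.6e9;                 // NominalCPUFrequency()-style value, in Hz
      const double frequency_mhz = nominal_hz * 1e-6;  // stored in DeviceProperties: 2600 MHz
      const int num_cores = 8;
      // GetDeviceInfo(): cores * MHz * 1e-3 == cores * GHz, i.e. ~GFLOPS at 1 op/cycle/core.
      const double gflops = num_cores * frequency_mhz * 1e-3;  // 20.8
      std::printf("%.1f GFLOPS\n", gflops);
      return 0;
    }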
tensorflow/core/grappler/costs/virtual_placer.cc (new file)
@@ -0,0 +1,57 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include "tensorflow/core/grappler/costs/virtual_placer.h"
#include "tensorflow/core/framework/node_def.pb.h"
#include "tensorflow/core/grappler/clusters/cluster.h"
#include "tensorflow/core/grappler/costs/utils.h"
#include "tensorflow/core/grappler/devices.h"
#include "tensorflow/core/util/device_name_utils.h"

namespace tensorflow {
namespace grappler {

VirtualPlacer::VirtualPlacer(Cluster* cluster) : has_gpu_(false) {
  devices_["CPU"] = GetLocalCPUInfo();
  if (GetNumAvailableGPUs() > 0) {
    has_gpu_ = true;
    devices_["GPU"] = GetLocalGPUInfo(0);
  }
  unknown_device_.set_type("UNKNOWN");
}

const OpInfo::DeviceProperties& VirtualPlacer::get_device(
    const NodeDef& node) const {
  string device_type;
  DeviceNameUtils::ParsedName parsed;
  if (!node.device().empty() &&
      DeviceNameUtils::ParseFullName(node.device(), &parsed)) {
    device_type = parsed.type;
  } else {
    if (has_gpu_) {
      device_type = "GPU";
    } else {
      device_type = "CPU";
    }
  }
  auto it = devices_.find(device_type);
  if (it == devices_.end()) {
    return unknown_device_;
  }
  return it->second;
}

}  // end namespace grappler
}  // end namespace tensorflow
tensorflow/core/grappler/costs/virtual_placer.h (new file)
@@ -0,0 +1,45 @@
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#ifndef TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
#define TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_

#include <unordered_map>
#include "tensorflow/core/grappler/costs/op_performance_data.pb.h"
#include "tensorflow/core/platform/types.h"

namespace tensorflow {
class NodeDef;

namespace grappler {
class Cluster;

// The virtual placer emulates the behavior of the TF placer.
class VirtualPlacer {
 public:
  VirtualPlacer(Cluster* cluster);

  const OpInfo::DeviceProperties& get_device(const NodeDef& node) const;

 private:
  std::unordered_map<string, OpInfo::DeviceProperties> devices_;
  bool has_gpu_;
  OpInfo::DeviceProperties unknown_device_;
};

}  // namespace grappler
}  // end namespace tensorflow

#endif  // TENSORFLOW_CORE_GRAPPLER_COSTS_VIRTUAL_PLACER_H_
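A small usage sketch for the placer (illustrative only, not part of the commit); the Cluster pointer and NodeDef are assumed to come from the caller:

    #include <string>

    #include "tensorflow/core/framework/node_def.pb.h"
    #include "tensorflow/core/grappler/costs/virtual_placer.h"

    // Returns the device type the virtual placer would assign to `node`:
    // the type parsed from the node's device string if present, otherwise
    // "GPU" when one is available, otherwise "CPU" (or "UNKNOWN").
    std::string PlacedDeviceType(tensorflow::grappler::Cluster* cluster,
                                 const tensorflow::NodeDef& node) {
      tensorflow::grappler::VirtualPlacer placer(cluster);
      return placer.get_device(node).type();
    }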
@@ -19,6 +19,9 @@ limitations under the License.

 #include "tensorflow/core/kernels/crop_and_resize_op.h"

+#include <functional>
+#include <string>
+
 #include "third_party/eigen3/unsupported/Eigen/CXX11/Tensor"
 #include "tensorflow/core/framework/op_kernel.h"
 #include "tensorflow/core/framework/register_types.h"

@@ -26,10 +29,13 @@ limitations under the License.
 #include "tensorflow/core/framework/tensor_shape.h"
 #include "tensorflow/core/framework/types.h"
 #include "tensorflow/core/kernels/bounds_check.h"
+#include "tensorflow/core/lib/core/errors.h"
 #include "tensorflow/core/lib/core/status.h"
 #include "tensorflow/core/platform/logging.h"
+#include "tensorflow/core/platform/types.h"

 #if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_event_mgr.h"
 #include "tensorflow/core/platform/stream_executor.h"
 #endif  // GOOGLE_CUDA
@@ -37,41 +43,67 @@ namespace tensorflow {

 typedef Eigen::ThreadPoolDevice CPUDevice;
 typedef Eigen::GpuDevice GPUDevice;
+using Callback = std::function<void()>;

-static inline void ParseAndCheckBoxSizes(OpKernelContext* context,
-                                         const Tensor& boxes,
-                                         const Tensor& box_ind,
-                                         int* num_boxes) {
-  if (boxes.NumElements() == 0 && box_ind.NumElements() == 0) {
+namespace {
+
+static inline Status ParseAndCheckBoxSizes(const Tensor& boxes,
+                                           const Tensor& box_index,
+                                           int* num_boxes) {
+  if (boxes.NumElements() == 0 && box_index.NumElements() == 0) {
     *num_boxes = 0;
-    return;
+    return Status::OK();
   }
   // The shape of 'boxes' is [num_boxes, 4].
-  OP_REQUIRES(context, boxes.dims() == 2,
-              errors::InvalidArgument("boxes must be 2-D",
-                                      boxes.shape().DebugString()));
+  if (boxes.dims() != 2) {
+    return errors::InvalidArgument("boxes must be 2-D",
+                                   boxes.shape().DebugString());
+  }
   *num_boxes = boxes.dim_size(0);
-  OP_REQUIRES(context, boxes.dim_size(1) == 4,
-              errors::InvalidArgument("boxes must have 4 columns"));
-  // The shape of 'box_ind' is [num_boxes].
-  OP_REQUIRES(context, box_ind.dims() == 1,
-              errors::InvalidArgument("box_ind must be 1-D",
-                                      box_ind.shape().DebugString()));
-  OP_REQUIRES(context, box_ind.dim_size(0) == *num_boxes,
-              errors::InvalidArgument("box_ind has incompatible shape"));
+  if (boxes.dim_size(1) != 4) {
+    return errors::InvalidArgument("boxes must have 4 columns");
+  }
+  // The shape of 'box_index' is [num_boxes].
+  if (box_index.dims() != 1) {
+    return errors::InvalidArgument("box_index must be 1-D",
+                                   box_index.shape().DebugString());
+  }
+  if (box_index.dim_size(0) != *num_boxes) {
+    return errors::InvalidArgument("box_index has incompatible shape");
+  }
+  return Status::OK();
 }

-// Verifies that all values in box_ind are in [0, batch).
+// Conditionally calls the compute callback if all values in box_index are in
+// [0, batch_size) then calls done.
 template <typename Device>
-inline void CheckValidBoxInd(
-    OpKernelContext* context,
-    typename TTypes<int32, 1>::ConstTensor box_ind_data, int batch);
+inline void RunIfBoxIndexIsValid(
+    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+    int batch_size, Callback compute, Callback done);
+
+// Specialization of CheckValidBoxIndex for a CPUDevice.
+template <>
+inline void RunIfBoxIndexIsValid<CPUDevice>(
+    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+    int batch_size, Callback compute, Callback done) {
+  const int num_boxes = box_index.dimension(0);
+  for (int b = 0; b < num_boxes; ++b) {
+    OP_REQUIRES_ASYNC(
+        context, FastBoundsCheck(box_index(b), batch_size),
+        errors::OutOfRange("box_index has values outside [0, batch_size)"),
+        done);
+  }
+  compute();
+  done();
+}
+
+}  // namespace

 template <typename Device, typename T>
-class CropAndResizeOp : public OpKernel {
+class CropAndResizeOp : public AsyncOpKernel {
  public:
-  explicit CropAndResizeOp(OpKernelConstruction* context) : OpKernel(context) {
+  explicit CropAndResizeOp(OpKernelConstruction* context)
+      : AsyncOpKernel(context) {
     string method;
     OP_REQUIRES_OK(context, context->GetAttr("method", &method));
     OP_REQUIRES(context, method == "bilinear",
@@ -80,69 +112,77 @@ class CropAndResizeOp : public OpKernel {
                                    &extrapolation_value_));
   }

-  void Compute(OpKernelContext* context) override {
-    // The shape of 'image' is [batch, image_height, image_width, channels].
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
+    // The shape of 'image' is [batch_size, image_height, image_width,
+    // channels].
     const Tensor& image = context->input(0);
-    OP_REQUIRES(context, image.dims() == 4,
-                errors::InvalidArgument("input image must be 4-D",
-                                        image.shape().DebugString()));
-
-    const int batch = image.dim_size(0);
-    const int image_height = image.dim_size(1);
-    const int image_width = image.dim_size(2);
-    const int depth = image.dim_size(3);
-    OP_REQUIRES(context, image_height > 0 && image_width > 0,
-                errors::InvalidArgument("image dimensions must be positive"));
-
     // The shape of 'boxes' is [num_boxes, 4].
     const Tensor& boxes = context->input(1);
-    // The shape of 'box_ind' is [num_boxes].
-    const Tensor& box_ind = context->input(2);
-
-    int num_boxes = 0;
-    ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
-
+    // The shape of 'box_index' is [num_boxes].
+    const Tensor& box_index = context->input(2);
     // The shape of 'crop_size' is [2].
     const Tensor& crop_size = context->input(3);

-    OP_REQUIRES(context, crop_size.dims() == 1,
-                errors::InvalidArgument("crop_size must be 1-D",
-                                        crop_size.shape().DebugString()));
-    OP_REQUIRES(context, crop_size.dim_size(0) == 2,
-                errors::InvalidArgument("crop_size must have two elements",
-                                        crop_size.shape().DebugString()));
+    // Validate inputs dimensions.
+    OP_REQUIRES_ASYNC(context, image.dims() == 4,
+                      errors::InvalidArgument("input image must be 4-D",
+                                              image.shape().DebugString()),
+                      done);
+    const int batch_size = image.dim_size(0);
+    const int image_height = image.dim_size(1);
+    const int image_width = image.dim_size(2);
+    const int depth = image.dim_size(3);
+    OP_REQUIRES_ASYNC(
+        context, image_height > 0 && image_width > 0,
+        errors::InvalidArgument("image dimensions must be positive"), done);
+    int num_boxes = 0;
+    OP_REQUIRES_OK_ASYNC(
+        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
+
+    OP_REQUIRES_ASYNC(context, crop_size.dims() == 1,
+                      errors::InvalidArgument("crop_size must be 1-D",
+                                              crop_size.shape().DebugString()),
+                      done);
+    OP_REQUIRES_ASYNC(
+        context, crop_size.dim_size(0) == 2,
+        errors::InvalidArgument("crop_size must have two elements",
+                                crop_size.shape().DebugString()),
+        done);
+
+    // Copy and validate crop sizes.
     auto crop_size_vec = crop_size.vec<int32>();
     const int crop_height = internal::SubtleMustCopy(crop_size_vec(0));
     const int crop_width = internal::SubtleMustCopy(crop_size_vec(1));
-    OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
-                errors::InvalidArgument("crop dimensions must be positive"));
+    OP_REQUIRES_ASYNC(
+        context, crop_height > 0 && crop_width > 0,
+        errors::InvalidArgument("crop dimensions must be positive"), done);

     // Allocate output tensor.
     Tensor* output = nullptr;
-    OP_REQUIRES_OK(
+    OP_REQUIRES_OK_ASYNC(
         context,
         context->allocate_output(
             0, TensorShape({num_boxes, crop_height, crop_width, depth}),
-            &output));
+            &output),
+        done);

-    typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
-    typename TTypes<float, 2>::ConstTensor boxes_data =
-        boxes.tensor<float, 2>();
-    typename TTypes<int32, 1>::ConstTensor box_ind_data =
-        box_ind.tensor<int32, 1>();
-    typename TTypes<float, 4>::Tensor crops_data = output->tensor<float, 4>();
-
-    CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
-    bool status = functor::CropAndResize<Device, T>()(
-        context->eigen_device<Device>(), image_data, boxes_data, box_ind_data,
-        extrapolation_value_, crops_data);
-    if (!status) {
-      context->SetStatus(
-          errors::Internal("Failed launch CropAndResizeKernel."));
-    }
+    auto compute_callback = [this, context, output]() {
+      const Tensor& image = context->input(0);
+      const Tensor& boxes = context->input(1);
+      const Tensor& box_index = context->input(2);
+      const bool status = functor::CropAndResize<Device, T>()(
+          context->eigen_device<Device>(), image.tensor<T, 4>(),
+          boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
+          extrapolation_value_, output->tensor<float, 4>());
+      if (!status) {
+        context->SetStatus(
+            errors::Internal("Failed launch CropAndResizeKernel."));
+      }
+    };
+
+    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+                                 batch_size, std::move(compute_callback),
+                                 std::move(done));
   }

  private:
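The hunk above is the template the remaining kernels in this file follow: derive from AsyncOpKernel, switch every validation macro to its _ASYNC variant (which takes the done callback), and defer the real work into a compute_callback that RunIfBoxIndexIsValid invokes only when the box indices check out. The sketch below is a simplified stand-in for what the _ASYNC macros do; it is illustrative only (MY_REQUIRES_ASYNC is a made-up name, and the real macros also record file and line information):

    // On a failed check: record the error on the context, invoke the
    // completion callback, and leave ComputeAsync immediately.
    #define MY_REQUIRES_ASYNC(ctx, cond, status, done_cb) \
      do {                                                \
        if (!(cond)) {                                    \
          (ctx)->CtxFailure((status));                    \
          (done_cb)();                                    \
          return;                                         \
        }                                                 \
      } while (0)

This is why `done` is threaded through every early exit: an AsyncOpKernel must guarantee the callback runs exactly once, whether the op fails validation, runs synchronously on the CPU, or finishes later on a GPU stream.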
@@ -155,10 +195,10 @@ template <typename T>
 struct CropAndResize<CPUDevice, T> {
   bool operator()(const CPUDevice& d, typename TTypes<T, 4>::ConstTensor image,
                   typename TTypes<float, 2>::ConstTensor boxes,
-                  typename TTypes<int32, 1>::ConstTensor box_ind,
+                  typename TTypes<int32, 1>::ConstTensor box_index,
                   float extrapolation_value,
                   typename TTypes<float, 4>::Tensor crops) {
-    const int batch = image.dimension(0);
+    const int batch_size = image.dimension(0);
     const int image_height = image.dimension(1);
     const int image_width = image.dimension(2);

@@ -173,8 +213,8 @@ struct CropAndResize<CPUDevice, T> {
       const float y2 = boxes(b, 2);
       const float x2 = boxes(b, 3);

-      const int32 b_in = box_ind(b);
-      if (b_in < 0 || b_in >= batch) {
+      const int32 b_in = box_index(b);
+      if (!FastBoundsCheck(b_in, batch_size)) {
         continue;
       }

@@ -235,89 +275,94 @@ struct CropAndResize<CPUDevice, T> {
     return true;
   }
 };

 }  // namespace functor

 template <typename Device, typename T>
-class CropAndResizeGradImageOp : public OpKernel {
+class CropAndResizeGradImageOp : public AsyncOpKernel {
  public:
   explicit CropAndResizeGradImageOp(OpKernelConstruction* context)
-      : OpKernel(context) {
+      : AsyncOpKernel(context) {
     string method;
     OP_REQUIRES_OK(context, context->GetAttr("method", &method));
     OP_REQUIRES(context, method == "bilinear",
                 errors::InvalidArgument("method must be 'bilinear'", method));
   }

-  void Compute(OpKernelContext* context) override {
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
     // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
     const Tensor& grads = context->input(0);
-
-    OP_REQUIRES(context, grads.dims() == 4,
-                errors::InvalidArgument("grads image must be 4-D",
-                                        grads.shape().DebugString()));
-    const int crop_height = grads.dim_size(1);
-    const int crop_width = grads.dim_size(2);
-    OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
-                errors::InvalidArgument("grads dimensions must be positive"));
-
     // The shape of 'boxes' is [num_boxes, 4].
     const Tensor& boxes = context->input(1);
-    // The shape of 'box_ind' is [num_boxes].
-    const Tensor& box_ind = context->input(2);
-
-    int num_boxes = 0;
-    ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
-
-    OP_REQUIRES(
-        context, grads.dim_size(0) == num_boxes,
-        errors::InvalidArgument("boxes and grads have incompatible shape"));
-
+    // The shape of 'box_index' is [num_boxes].
+    const Tensor& box_index = context->input(2);
     // The shape of 'image_size' is [4].
     const Tensor& image_size = context->input(3);
-    OP_REQUIRES(context, image_size.dims() == 1,
-                errors::InvalidArgument("image_size must be 1-D",
-                                        image_size.shape().DebugString()));
-    OP_REQUIRES(context, image_size.dim_size(0) == 4,
-                errors::InvalidArgument("image_size must have 4 elements",
-                                        image_size.shape().DebugString()));

+    // Validate input shapes.
+    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
+                      errors::InvalidArgument("grads image must be 4-D",
+                                              grads.shape().DebugString()),
+                      done);
+    const int crop_height = grads.dim_size(1);
+    const int crop_width = grads.dim_size(2);
+    OP_REQUIRES_ASYNC(
+        context, crop_height > 0 && crop_width > 0,
+        errors::InvalidArgument("grads dimensions must be positive"), done);
+    int num_boxes = 0;
+    OP_REQUIRES_OK_ASYNC(
+        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);
+    OP_REQUIRES_ASYNC(
+        context, grads.dim_size(0) == num_boxes,
+        errors::InvalidArgument("boxes and grads have incompatible shape"),
+        done);
+
+    OP_REQUIRES_ASYNC(context, image_size.dims() == 1,
+                      errors::InvalidArgument("image_size must be 1-D",
+                                              image_size.shape().DebugString()),
+                      done);
+    OP_REQUIRES_ASYNC(context, image_size.dim_size(0) == 4,
+                      errors::InvalidArgument("image_size must have 4 elements",
+                                              image_size.shape().DebugString()),
+                      done);
     auto image_size_vec = image_size.vec<int32>();
-    const int batch = internal::SubtleMustCopy(image_size_vec(0));
+    const int batch_size = internal::SubtleMustCopy(image_size_vec(0));
     const int image_height = internal::SubtleMustCopy(image_size_vec(1));
     const int image_width = internal::SubtleMustCopy(image_size_vec(2));
     const int depth = internal::SubtleMustCopy(image_size_vec(3));
-    OP_REQUIRES(context, image_height > 0 && image_width > 0,
-                errors::InvalidArgument("image dimensions must be positive"));
-    OP_REQUIRES(
+    OP_REQUIRES_ASYNC(
+        context, image_height > 0 && image_width > 0,
+        errors::InvalidArgument("image dimensions must be positive"), done);
+    OP_REQUIRES_ASYNC(
         context, grads.dim_size(3) == depth,
-        errors::InvalidArgument("image_size and grads are incompatible"));
+        errors::InvalidArgument("image_size and grads are incompatible"), done);

     // Allocate output tensor.
     Tensor* output = nullptr;
-    OP_REQUIRES_OK(
-        context, context->allocate_output(
-                     0, TensorShape({batch, image_height, image_width, depth}),
-                     &output));
+    OP_REQUIRES_OK_ASYNC(
+        context,
+        context->allocate_output(
+            0, TensorShape({batch_size, image_height, image_width, depth}),
+            &output),
+        done);

-    typename TTypes<float, 4>::ConstTensor grads_data =
-        grads.tensor<float, 4>();
-    typename TTypes<float, 2>::ConstTensor boxes_data =
-        boxes.tensor<float, 2>();
-    typename TTypes<int32, 1>::ConstTensor box_ind_data =
-        box_ind.tensor<int32, 1>();
-    typename TTypes<T, 4>::Tensor output_data = output->tensor<T, 4>();
-
-    CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
-    bool status = functor::CropAndResizeBackpropImage<Device, T>()(
-        context->eigen_device<Device>(), grads_data, boxes_data, box_ind_data,
-        output_data);
-    if (!status) {
-      context->SetStatus(
-          errors::Internal("Failed launch CropAndResizeBackpropImageKernel."));
-    }
+    auto compute_callback = [context, output]() {
+      const Tensor& grads = context->input(0);
+      const Tensor& boxes = context->input(1);
+      const Tensor& box_index = context->input(2);
+      const bool status = functor::CropAndResizeBackpropImage<Device, T>()(
+          context->eigen_device<Device>(), grads.tensor<float, 4>(),
+          boxes.tensor<float, 2>(), box_index.tensor<int32, 1>(),
+          output->tensor<T, 4>());
+      if (!status) {
+        context->SetStatus(errors::Internal(
+            "Failed launch CropAndResizeBackpropImage kernel."));
+      }
+    };
+
+    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+                                 batch_size, std::move(compute_callback),
+                                 std::move(done));
   }
 };

@@ -328,9 +373,9 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
   bool operator()(const CPUDevice& d,
                   typename TTypes<float, 4>::ConstTensor grads,
                   typename TTypes<float, 2>::ConstTensor boxes,
-                  typename TTypes<int32, 1>::ConstTensor box_ind,
+                  typename TTypes<int32, 1>::ConstTensor box_index,
                   typename TTypes<T, 4>::Tensor grads_image) {
-    const int batch = grads_image.dimension(0);
+    const int batch_size = grads_image.dimension(0);
     const int image_height = grads_image.dimension(1);
     const int image_width = grads_image.dimension(2);

@@ -347,8 +392,8 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
       const float y2 = boxes(b, 2);
       const float x2 = boxes(b, 3);

-      const int32 b_in = box_ind(b);
-      if (b_in < 0 || b_in >= batch) {
+      const int32 b_in = box_index(b);
+      if (!FastBoundsCheck(b_in, batch_size)) {
         continue;
       }

@@ -399,83 +444,90 @@ struct CropAndResizeBackpropImage<CPUDevice, T> {
     return true;
   }
 };

 }  // namespace functor

 template <typename Device, typename T>
-class CropAndResizeGradBoxesOp : public OpKernel {
+class CropAndResizeGradBoxesOp : public AsyncOpKernel {
  public:
   explicit CropAndResizeGradBoxesOp(OpKernelConstruction* context)
-      : OpKernel(context) {
+      : AsyncOpKernel(context) {
     string method;
     OP_REQUIRES_OK(context, context->GetAttr("method", &method));
     OP_REQUIRES(context, method == "bilinear",
                 errors::InvalidArgument("method must be 'bilinear'", method));
   }

-  void Compute(OpKernelContext* context) override {
+  void ComputeAsync(OpKernelContext* context, DoneCallback done) override {
     // The shape of 'grads' is [num_boxes, crop_height, crop_width, depth].
     const Tensor& grads = context->input(0);
+    // The shape of 'boxes' is [num_boxes, 4].
+    const Tensor& boxes = context->input(2);
+    // The shape of 'box_index' is [num_boxes].
+    const Tensor& box_index = context->input(3);
+    // The shape of 'image' is [batch_size, image_height, image_width, depth].
+    const Tensor& image = context->input(1);

-    OP_REQUIRES(context, grads.dims() == 4,
-                errors::InvalidArgument("grads image must be 4-D",
-                                        grads.shape().DebugString()));
+    // Validate input shapes.
+    OP_REQUIRES_ASYNC(context, grads.dims() == 4,
+                      errors::InvalidArgument("grads image must be 4-D",
+                                              grads.shape().DebugString()),
+                      done);
     const int crop_height = grads.dim_size(1);
     const int crop_width = grads.dim_size(2);
     const int depth = grads.dim_size(3);
-    OP_REQUIRES(context, crop_height > 0 && crop_width > 0,
-                errors::InvalidArgument("grads dimensions must be positive"));
+    OP_REQUIRES_ASYNC(
+        context, crop_height > 0 && crop_width > 0,
+        errors::InvalidArgument("grads dimensions must be positive"), done);

-    // The shape of 'image' is [batch, image_height, image_width, depth].
-    const Tensor& image = context->input(1);
-    OP_REQUIRES(context, image.dims() == 4,
-                errors::InvalidArgument("input image must be 4-D",
-                                        image.shape().DebugString()));
-
-    const int batch = image.dim_size(0);
+    OP_REQUIRES_ASYNC(context, image.dims() == 4,
+                      errors::InvalidArgument("input image must be 4-D",
+                                              image.shape().DebugString()),
+                      done);
+    const int batch_size = image.dim_size(0);
     const int image_height = image.dim_size(1);
     const int image_width = image.dim_size(2);
-    OP_REQUIRES(context, image_height > 0 && image_width > 0,
-                errors::InvalidArgument("image dimensions must be positive"));
-    OP_REQUIRES(context, image.dim_size(3) == depth,
-                errors::InvalidArgument("image, grads depth differ"));
-
-    // The shape of 'boxes' is [num_boxes, 4].
-    const Tensor& boxes = context->input(2);
-
-    // The shape of 'box_ind' is [num_boxes].
-    const Tensor& box_ind = context->input(3);
+    OP_REQUIRES_ASYNC(
+        context, image_height > 0 && image_width > 0,
+        errors::InvalidArgument("image dimensions must be positive"), done);
+    OP_REQUIRES_ASYNC(context, image.dim_size(3) == depth,
+                      errors::InvalidArgument("image, grads depth differ"),
+                      done);

     int num_boxes = 0;
-    ParseAndCheckBoxSizes(context, boxes, box_ind, &num_boxes);
+    OP_REQUIRES_OK_ASYNC(
+        context, ParseAndCheckBoxSizes(boxes, box_index, &num_boxes), done);

-    OP_REQUIRES(
+    OP_REQUIRES_ASYNC(
         context, grads.dim_size(0) == num_boxes,
-        errors::InvalidArgument("boxes and grads have incompatible shape"));
+        errors::InvalidArgument("boxes and grads have incompatible shape"),
+        done);

     // Allocate output tensor.
     Tensor* output = nullptr;
-    OP_REQUIRES_OK(context, context->allocate_output(
-                                0, TensorShape({num_boxes, 4}), &output));
+    OP_REQUIRES_OK_ASYNC(
+        context,
+        context->allocate_output(0, TensorShape({num_boxes, 4}), &output),
+        done);

-    typename TTypes<float, 4>::ConstTensor grads_data =
-        grads.tensor<float, 4>();
-    typename TTypes<T, 4>::ConstTensor image_data = image.tensor<T, 4>();
-    typename TTypes<float, 2>::ConstTensor boxes_data =
-        boxes.tensor<float, 2>();
-    typename TTypes<int32, 1>::ConstTensor box_ind_data =
-        box_ind.tensor<int32, 1>();
-    typename TTypes<float, 2>::Tensor output_data = output->tensor<float, 2>();
-
-    CheckValidBoxInd<Device>(context, box_ind_data, batch);
-
-    bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
-        context->eigen_device<Device>(), grads_data, image_data, boxes_data,
-        box_ind_data, output_data);
-    if (!status) {
-      context->SetStatus(
-          errors::Internal("Failed launch CropAndResizeBackpropBoxesKernel."));
-    }
+    auto compute_callback = [context, output]() {
+      const Tensor& grads = context->input(0);
+      const Tensor& image = context->input(1);
+      const Tensor& boxes = context->input(2);
+      const Tensor& box_index = context->input(3);
+      const bool status = functor::CropAndResizeBackpropBoxes<Device, T>()(
+          context->eigen_device<Device>(), grads.tensor<float, 4>(),
+          image.tensor<T, 4>(), boxes.tensor<float, 2>(),
+          box_index.tensor<int32, 1>(), output->tensor<float, 2>());
+      if (!status) {
+        context->SetStatus(errors::Internal(
+            "Failed launch CropAndResizeBackpropBoxes kernel."));
+      }
+    };
+
+    RunIfBoxIndexIsValid<Device>(context, box_index.tensor<int32, 1>(),
+                                 batch_size, std::move(compute_callback),
+                                 std::move(done));
   }
 };

@@ -487,9 +539,9 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
                   typename TTypes<float, 4>::ConstTensor grads,
                   typename TTypes<T, 4>::ConstTensor image,
                   typename TTypes<float, 2>::ConstTensor boxes,
-                  typename TTypes<int32, 1>::ConstTensor box_ind,
+                  typename TTypes<int32, 1>::ConstTensor box_index,
                   typename TTypes<float, 2>::Tensor grads_boxes) {
-    const int batch = image.dimension(0);
+    const int batch_size = image.dimension(0);
     const int image_height = image.dimension(1);
     const int image_width = image.dimension(2);

@@ -506,8 +558,8 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
       const float y2 = boxes(b, 2);
       const float x2 = boxes(b, 3);

-      const int32 b_in = box_ind(b);
-      if (b_in < 0 || b_in >= batch) {
+      const int32 b_in = box_index(b);
+      if (!FastBoundsCheck(b_in, batch_size)) {
         continue;
       }

@@ -589,30 +641,19 @@ struct CropAndResizeBackpropBoxes<CPUDevice, T> {
     return true;
   }
 };

 }  // namespace functor

-// Specialization of CheckValidBoxInd for a CPUDevice.
-template <>
-inline void CheckValidBoxInd<CPUDevice>(
-    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
-    int batch) {
-  const int num_boxes = box_ind.dimension(0);
-  for (int b = 0; b < num_boxes; ++b) {
-    OP_REQUIRES(context, box_ind(b) >= 0 && box_ind(b) < batch,
-                errors::OutOfRange("box_ind has values outside [0, batch)"));
-  }
-}
-
 #define REGISTER_KERNEL(T)                                \
   REGISTER_KERNEL_BUILDER(Name("CropAndResize")           \
                               .Device(DEVICE_CPU)         \
                               .TypeConstraint<T>("T")     \
                               .HostMemory("crop_size"),   \
                           CropAndResizeOp<CPUDevice, T>); \
                                                           \
   REGISTER_KERNEL_BUILDER(Name("CropAndResizeGradBoxes")  \
                               .Device(DEVICE_CPU)         \
                               .TypeConstraint<T>("T"),    \
                           CropAndResizeGradBoxesOp<CPUDevice, T>);

 TF_CALL_REAL_NUMBER_TYPES(REGISTER_KERNEL);
@@ -634,50 +675,86 @@ TF_CALL_double(REGISTER_KERNEL);

 #if GOOGLE_CUDA

-// Forward declaration of the CheckValidBoxIndHelper specialization for GPU.
+// Forward declaration of the CheckValidBoxIndexHelper specialization for GPU.
 namespace functor {
 template <>
-void CheckValidBoxIndHelper<GPUDevice>::operator()(
-    const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_ind,
-    int batch, typename TTypes<bool, 0>::Tensor isvalid);
-extern template struct CheckValidBoxIndHelper<GPUDevice>;
+void CheckValidBoxIndexHelper<GPUDevice>::operator()(
+    const GPUDevice& d, typename TTypes<int32, 1>::ConstTensor box_index,
+    int batch_size, typename TTypes<bool, 0>::Tensor isvalid);
+extern template struct CheckValidBoxIndexHelper<GPUDevice>;
 }  // namespace functor

-// Specialization of CheckValidBoxInd for a GPUDevice.
+namespace {
+
+// Specialization of CheckValidBoxIndex for a GPUDevice.
 template <>
-inline void CheckValidBoxInd<GPUDevice>(
-    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_ind,
-    int batch) {
-  const int num_boxes = box_ind.dimension(0);
+inline void RunIfBoxIndexIsValid<GPUDevice>(
+    OpKernelContext* context, typename TTypes<int32, 1>::ConstTensor box_index,
+    int batch_size, Callback compute, Callback done) {
+  const int num_boxes = box_index.dimension(0);
   if (num_boxes == 0) {
+    compute();
+    done();
     return;
   }
-  Tensor isvalid_tensor;
-  OP_REQUIRES_OK(context,
-                 context->allocate_temp(DataTypeToEnum<bool>::value,
-                                        TensorShape({}), &isvalid_tensor));
-
-  typename TTypes<bool, 0>::Tensor isvalid = isvalid_tensor.tensor<bool, 0>();
-
-  functor::CheckValidBoxIndHelper<GPUDevice>()(
-      context->eigen_device<GPUDevice>(), box_ind, batch, isvalid);
+
+  Tensor isvalid_dev_tensor;
+  OP_REQUIRES_OK_ASYNC(
+      context,
+      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
+                             &isvalid_dev_tensor),
+      done);
+  typename TTypes<bool, 0>::Tensor isvalid_dev =
+      isvalid_dev_tensor.tensor<bool, 0>();
+
+  // Run the actual box check on the device.
+  functor::CheckValidBoxIndexHelper<GPUDevice>()(
+      context->eigen_device<GPUDevice>(), box_index, batch_size, isvalid_dev);

+  // Copy the result back to the host.
   auto* stream = context->op_device_context()->stream();
-  OP_REQUIRES(context, stream, errors::Internal("No GPU stream available."));
+  OP_REQUIRES_ASYNC(context, stream,
+                    errors::Internal("No GPU stream available."), done);
+  Tensor isvalid_host_tensor;
+  // Use pinned host memory on the host to avoid unnecessary
+  // synchronization.
+  AllocatorAttributes alloc_attr;
+  alloc_attr.set_on_host(true);
+  alloc_attr.set_gpu_compatible(true);
+  OP_REQUIRES_OK_ASYNC(
+      context,
+      context->allocate_temp(DataTypeToEnum<bool>::value, TensorShape({}),
+                             &isvalid_host_tensor, alloc_attr),
+      done);
+  perftools::gputools::DeviceMemoryBase wrapped(isvalid_dev.data(),
+                                                sizeof(bool));
+  const bool status =
+      stream
+          ->ThenMemcpy(
+              isvalid_host_tensor.scalar<bool>().data() /* destination */,
+              wrapped /* source */, sizeof(bool))
+          .ok();
+  OP_REQUIRES_ASYNC(
+      context, status,
+      errors::Internal("Failed to launch copy of isvalid from device to host."),
+      done);

-  bool isvalid_host = false;
-  perftools::gputools::DeviceMemoryBase isvalid_gpu(isvalid.data(),
-                                                    sizeof(bool));
-  stream->ThenMemcpy(&isvalid_host, isvalid_gpu, sizeof(bool));
-  stream->BlockHostUntilDone();
-
-  OP_REQUIRES(context, stream->ok(),
-              errors::Internal("cudaMemcpy from device to host failed"));
-
-  OP_REQUIRES(context, isvalid_host,
-              errors::OutOfRange("box_ind has values outside [0, batch)"));
+  auto wrapped_callback = [context, isvalid_host_tensor, compute, done]() {
+    const bool isvalid = isvalid_host_tensor.scalar<bool>()();
+    OP_REQUIRES_ASYNC(
+        context, isvalid,
+        errors::OutOfRange("box_index has values outside [0, batch_size)"),
+        done);
+    compute();
+    done();
+  };
+
+  context->device()->tensorflow_gpu_device_info()->event_mgr->ThenExecute(
+      stream, wrapped_callback);
 }

+}  // namespace
+
 #define REGISTER_KERNEL(T)                      \
   REGISTER_KERNEL_BUILDER(Name("CropAndResize") \
                               .Device(DEVICE_GPU) \
@@ -53,12 +53,12 @@ struct CropAndResizeBackpropBoxes {
 };

 template <typename Device>
-struct CheckValidBoxIndHelper {
-  // Checks if all values in box_ind are in [0, batch).
+struct CheckValidBoxIndexHelper {
+  // Checks if all values in box_index are in [0, batch).
   void operator()(const Device& d,
-                  typename TTypes<int32, 1>::ConstTensor box_ind, int batch,
+                  typename TTypes<int32, 1>::ConstTensor box_index, int batch,
                   typename TTypes<bool, 0>::Tensor isvalid) {
-    isvalid.device(d) = ((box_ind >= 0) && (box_ind < batch)).all();
+    isvalid.device(d) = ((box_index >= 0) && (box_index < batch)).all();
   }
 };

@@ -440,7 +440,7 @@ TF_CALL_GPU_NUMBER_TYPES(DEFINE_GPU_SPECS);

 #undef DEFINE_GPU_SPECS

-template struct CheckValidBoxIndHelper<GPUDevice>;
+template struct CheckValidBoxIndexHelper<GPUDevice>;

 }  // namespace functor
 }  // namespace tensorflow

@@ -251,7 +251,7 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndexShape) {
   Status s = RunOpKernel();
   ASSERT_FALSE(s.ok());
   EXPECT_TRUE(
-      StringPiece(s.ToString()).contains("box_ind has incompatible shape"))
+      StringPiece(s.ToString()).contains("box_index has incompatible shape"))
       << s;
 }

@@ -264,8 +264,10 @@ TEST_F(CropAndResizeOpTest, TestInvalidBoxIndex) {
   Status s = RunOpKernel();
   ASSERT_FALSE(s.ok());
   EXPECT_TRUE(StringPiece(s.ToString())
-                  .contains("box_ind has values outside [0, batch)"))
+                  .contains("box_index has values outside [0, batch_size)"))
       << s;
 }

+// TODO(zhengxq, rmlarsen): Add a benchmark.
+
 }  // namespace tensorflow
@ -20,4 +20,4 @@ REGISTER2(BinaryOp, CPU, "Atan2", functor::atan2, float, double);
|
|||||||
#if GOOGLE_CUDA
|
#if GOOGLE_CUDA
|
||||||
REGISTER2(BinaryOp, GPU, "Atan2", functor::atan2, float, double);
|
REGISTER2(BinaryOp, GPU, "Atan2", functor::atan2, float, double);
|
||||||
#endif
|
#endif
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
@ -23,4 +23,4 @@ DEFINE_BINARY2(atan2, float, double);
|
|||||||
} // namespace functor
|
} // namespace functor
|
||||||
} // namespace tensorflow
|
} // namespace tensorflow
|
||||||
|
|
||||||
#endif // GOOGLE_CUDA
|
#endif // GOOGLE_CUDA
|
||||||
@@ -155,7 +155,8 @@ void LinearAlgebraOp<Scalar>::AnalyzeInputs(OpKernelContext* context,
    const int col_dimension = input_rank - 1;
    const int64 num_rows = in.dim_size(row_dimension);
    const int64 num_cols = in.dim_size(col_dimension);
-    input_matrix_shapes->emplace_back(std::initializer_list<int64>({num_rows, num_cols}));
+    input_matrix_shapes->emplace_back(
+        std::initializer_list<int64>({num_rows, num_cols}));
    inputs->emplace_back(&in);
  }
  // Have the derived class validate that the inputs are as expected.

@@ -233,8 +234,7 @@ void LinearAlgebraOp<Scalar>::ComputeTensorSlice(
    matrix_inputs.emplace_back(
        inputs[i]->flat<Scalar>().data() +
            matrix_index * input_matrix_shapes[i].num_elements(),
-        input_matrix_shapes[i].dim_size(0),
-        input_matrix_shapes[i].dim_size(1));
+        input_matrix_shapes[i].dim_size(0), input_matrix_shapes[i].dim_size(1));
  }

  MatrixMaps matrix_outputs;
@@ -1716,6 +1716,31 @@ op {
    }
  }
}
+op {
+  name: "Atan2"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+}
op {
  name: "AudioSpectrogram"
  input_arg {

@@ -1904,6 +1904,33 @@ op {
  }
  summary: "Computes atan of x element-wise."
}
+op {
+  name: "Atan2"
+  input_arg {
+    name: "y"
+    type_attr: "T"
+  }
+  input_arg {
+    name: "x"
+    type_attr: "T"
+  }
+  output_arg {
+    name: "z"
+    type_attr: "T"
+  }
+  attr {
+    name: "T"
+    type: "type"
+    allowed_values {
+      list {
+        type: DT_FLOAT
+        type: DT_DOUBLE
+      }
+    }
+  }
+  summary: "Computes arctangent of `y/x` element-wise, respecting signs of the arguments."
+  description: "This is the angle \\( \\theta \\in [-\\pi, \\pi] \\) such that\n\\[ x = r \\cos(\\theta) \\]\nand\n\\[ y = r \\sin(\\theta) \\]\nwhere \\(r = \\sqrt(x^2 + y^2) \\)."
+}
op {
  name: "AudioSpectrogram"
  input_arg {
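The new op mirrors the C standard `atan2`. A small sketch of the intended semantics, assuming the op is exposed in Python as `tf.atan2` (the registration above only covers float and double):

```python
import math
import tensorflow as tf

y = tf.constant([1.0, -1.0, 0.0])
x = tf.constant([-1.0, -1.0, -2.0])

# atan2 resolves the quadrant from the signs of both arguments,
# unlike atan(y / x), which collapses opposite quadrants together.
theta = tf.atan2(y, x)

with tf.Session() as sess:
    print(sess.run(theta))        # ~[ 2.3562, -2.3562,  3.1416]
    print(math.atan2(1.0, -1.0))  # 2.356194..., matches the first entry
```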
@@ -1,33 +0,0 @@
(Deleted file: the old `contrib.distributions.bijector.md` page. Its 33 lines match the new page added below, except that every symbol path uses the singular `bijector` module name, e.g. `@{tf.contrib.distributions.bijector.Affine}`.)

@@ -0,0 +1,33 @@
+# Random variable transformations (contrib)
+[TOC]
+
+Bijector Ops.
+
+An API for invertible, differentiable transformations of random variables.
+
+## Background
+
+Differentiable, bijective transformations of continuous random variables alter
+the calculations made in the cumulative/probability distribution functions and
+sample function. This module provides a standard interface for making these
+manipulations.
+
+For more details and examples, see the `Bijector` docstring.
+
+To apply a `Bijector`, use `distributions.TransformedDistribution`.
+
+## Bijectors
+
+* @{tf.contrib.distributions.bijectors.Affine}
+* @{tf.contrib.distributions.bijectors.AffineLinearOperator}
+* @{tf.contrib.distributions.bijectors.Bijector}
+* @{tf.contrib.distributions.bijectors.Chain}
+* @{tf.contrib.distributions.bijectors.CholeskyOuterProduct}
+* @{tf.contrib.distributions.bijectors.Exp}
+* @{tf.contrib.distributions.bijectors.Identity}
+* @{tf.contrib.distributions.bijectors.Inline}
+* @{tf.contrib.distributions.bijectors.Invert}
+* @{tf.contrib.distributions.bijectors.PowerTransform}
+* @{tf.contrib.distributions.bijectors.SigmoidCentered}
+* @{tf.contrib.distributions.bijectors.SoftmaxCentered}
+* @{tf.contrib.distributions.bijectors.Softplus}
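To make the relocated doc concrete, here is a minimal sketch of the pattern it describes, assuming the `tf.contrib.distributions` API of this era: a `TransformedDistribution` wrapping the `Exp` bijector to build a log-normal.

```python
import tensorflow as tf

ds = tf.contrib.distributions

# Push a standard normal through Exp; the result is distributed log-normally.
log_normal = ds.TransformedDistribution(
    distribution=ds.Normal(loc=0., scale=1.),
    bijector=ds.bijectors.Exp())

samples = log_normal.sample(5)
with tf.Session() as sess:
    print(sess.run(samples))  # five positive draws
```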
@@ -76,7 +76,7 @@ representing the posterior or posterior predictive.

## Kullback-Leibler Divergence

-* @{tf.contrib.distributions.kl}
+* @{tf.contrib.distributions.kl_divergence}
* @{tf.contrib.distributions.RegisterKL}

## Utilities
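A short sketch of the renamed helper, assuming the `tf.contrib.distributions` symbols referenced above:

```python
import tensorflow as tf

ds = tf.contrib.distributions

p = ds.Normal(loc=0., scale=1.)
q = ds.Normal(loc=1., scale=2.)

# Closed-form KL(p || q) between two Gaussians; formerly exposed as
# tf.contrib.distributions.kl, now kl_divergence.
kl = ds.kl_divergence(p, q)

with tf.Session() as sess:
    print(sess.run(kl))  # ~0.4431 = ln(2) + (1 + 1) / (2 * 4) - 0.5
```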
@@ -40,7 +40,7 @@
* [Losses (contrib)](contrib.losses.md)
* [Metrics (contrib)](contrib.metrics.md)
* [Optimization (contrib)](contrib.opt.md)
-* [Random variable transformations (contrib)](contrib.distributions.bijector.md)
+* [Random variable transformations (contrib)](contrib.distributions.bijectors.md)
* [RNN and Cells (contrib)](contrib.rnn.md)
* [Seq2seq Library (contrib)](contrib.seq2seq.md)
* [Staging (contrib)](contrib.staging.md)
@@ -80,10 +80,12 @@ section.
* **OS:** Ubuntu 16.04 LTS with tests run via Docker
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
  //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** Local SSD
* **DataSet:** ImageNet
+* **Test Date:** May 2017

Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3, ResNet-50,

@@ -120,19 +122,19 @@ VGG16 | replicated (with NCCL) | n/a

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
-1 | 142 | 238 | 95.6 | 2987 | 154
-2 | 284 | 479 | 187 | 5658 | 295
-4 | 569 | 948 | 374 | 10509 | 584
-8 | 1131 | 1886 | 744 | 17822 | 1081
+1 | 142 | 219 | 91.8 | 2987 | 154
+2 | 284 | 422 | 181 | 5658 | 295
+4 | 569 | 852 | 356 | 10509 | 584
+8 | 1131 | 1734 | 716 | 17822 | 1081

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
-1 | 142 | 239 | 95.5 | 2890 | 154
-2 | 278 | 468 | 187 | 4448 | 284
-4 | 551 | 938 | 373 | 7105 | 534
-8 | 1079 | 1802 | 721 | N/A | 898
+1 | 142 | 218 | 91.4 | 2890 | 154
+2 | 278 | 425 | 179 | 4448 | 284
+4 | 551 | 853 | 359 | 7105 | 534
+8 | 1079 | 1630 | 708 | N/A | 898

Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to it maxing out the input pipeline.

@@ -145,19 +147,19 @@ The results below are all with a batch size of 32.

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
-1 | 128 | 210 | 85.3 | 144
-2 | 259 | 412 | 166 | 281
-4 | 520 | 827 | 330 | 549
-8 | 995 | 1623 | 643 | 820
+1 | 128 | 195 | 82.7 | 144
+2 | 259 | 368 | 160 | 281
+4 | 520 | 768 | 317 | 549
+8 | 995 | 1485 | 632 | 820

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
---- | ----------- | --------- | ---------- | -----
-1 | 130 | 208 | 85.0 | 144
-2 | 257 | 403 | 163 | 253
-4 | 507 | 814 | 325 | 457
-8 | 966 | 1525 | 641 | 690
+1 | 130 | 193 | 82.4 | 144
+2 | 257 | 369 | 159 | 253
+4 | 507 | 760 | 317 | 457
+8 | 966 | 1410 | 609 | 690

## Details for Google Compute Engine (NVIDIA® Tesla® K80)

@@ -168,11 +170,12 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | VGG16
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
  //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.7 TB Shared SSD persistent disk (800 MB/s)
* **DataSet:** ImageNet
-* **Test Date:** April 2017
+* **Test Date:** May 2017

Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were

@@ -198,19 +201,19 @@ The configuration used for each model was `variable_update` equal to

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.5 | 56.8 | 20.8 | 656 | 35.4
-2 | 57.8 | 107 | 39.1 | 1209 | 64.8
-4 | 116 | 212 | 77.2 | 2328 | 120
-8 | 227 | 419 | 151 | 4640 | 234
+1 | 30.5 | 51.9 | 20.0 | 656 | 35.4
+2 | 57.8 | 99.0 | 38.2 | 1209 | 64.8
+4 | 116 | 195 | 75.8 | 2328 | 120
+8 | 227 | 387 | 148 | 4640 | 234

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.6 | 56.7 | 20.7 | 639 | 34.2
-2 | 58.4 | 107 | 39.0 | 1136 | 62.9
-4 | 115 | 211 | 77.3 | 2067 | 118
-8 | 225 | 422 | 151 | 4056 | 230
+1 | 30.6 | 51.2 | 20.0 | 639 | 34.2
+2 | 58.4 | 98.8 | 38.3 | 1136 | 62.9
+4 | 115 | 194 | 75.4 | 2067 | 118
+8 | 225 | 381 | 148 | 4056 | 230

### Other Results

@@ -218,19 +221,19 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16

GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 29.3 | 53.9
-2 | 55.0 | 101
-4 | 109 | 200
-8 | 216 | 398
+1 | 29.3 | 49.5
+2 | 55.0 | 95.4
+4 | 109 | 183
+8 | 216 | 362

**Training real data**

GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 29.5 | 53.6
-2 | 55.4 | 102
-4 | 110 | 201
-8 | 216 | 387
+1 | 29.5 | 49.3
+2 | 55.4 | 95.3
+4 | 110 | 186
+8 | 216 | 359

## Details for Amazon EC2 (NVIDIA® Tesla® K80)

@@ -241,12 +244,13 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
  //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1TB Amazon EFS (burst 100 MiB/sec for 12 hours, continuous 50
  MiB/sec)
* **DataSet:** ImageNet
-* **Test Date:** April 2017
+* **Test Date:** May 2017

Batch size and optimizer used for each model are listed in the table below. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were

@@ -279,19 +283,19 @@ VGG16 | parameter_server | gpu

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.8 | 56.3 | 20.9 | 684 | 36.3
-2 | 58.7 | 108 | 39.3 | 1244 | 69.4
-4 | 117 | 217 | 79.1 | 2479 | 141
-8 | 230 | 419 | 156 | 4853 | 260
+1 | 30.8 | 51.5 | 19.7 | 684 | 36.3
+2 | 58.7 | 98.0 | 37.6 | 1244 | 69.4
+4 | 117 | 195 | 74.9 | 2479 | 141
+8 | 230 | 384 | 149 | 4853 | 260

**Training real data**

GPUs | InceptionV3 | ResNet-50 | ResNet-152 | Alexnet | VGG16
---- | ----------- | --------- | ---------- | ------- | -----
-1 | 30.5 | 56.0 | 20.6 | 674 | 36.3
-2 | 59.0 | 107 | 39.0 | 1227 | 67.5
-4 | 118 | 205 | 77.9 | 2201 | 136
-8 | 228 | 405 | 152 | N/A | 242
+1 | 30.5 | 51.3 | 19.7 | 674 | 36.3
+2 | 59.0 | 94.9 | 38.2 | 1227 | 67.5
+4 | 118 | 188 | 75.2 | 2201 | 136
+8 | 228 | 373 | 149 | N/A | 242

Training AlexNet with real data on 8 GPUs was excluded from the graph and table
above due to our EFS setup not providing enough throughput.

@@ -302,19 +306,19 @@ above due to our EFS setup not providing enough throughput.

GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 29.9 | 53.5
-2 | 57.5 | 101
-4 | 114 | 202
-8 | 216 | 380
+1 | 29.9 | 49.0
+2 | 57.5 | 94.1
+4 | 114 | 184
+8 | 216 | 355

**Training real data**

GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 30.0 | 53.6
-2 | 57.5 | 102
-4 | 113 | 202
-8 | 212 | 379
+1 | 30.0 | 49.1
+2 | 57.5 | 95.1
+4 | 113 | 185
+8 | 212 | 353

## Details for Amazon EC2 Distributed (NVIDIA® Tesla® K80)

@@ -325,11 +329,12 @@ GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
* **OS:** Ubuntu 16.04 LTS
* **CUDA / cuDNN:** 8.0 / 5.1
* **TensorFlow GitHub hash:** b1e174e
+* **Benchmark GitHub hash:** 9165a70
* **Build Command:** `bazel build -c opt --copt=-march="haswell" --config=cuda
  //tensorflow/tools/pip_package:build_pip_package`
* **Disk:** 1.0 TB EFS (burst 100 MB/sec for 12 hours, continuous 50 MB/sec)
* **DataSet:** ImageNet
-* **Test Date:** April 2017
+* **Test Date:** May 2017

The batch size and optimizer used for the tests are listed in the table. In
addition to the batch sizes listed in the table, InceptionV3 and ResNet-50 were

@@ -343,11 +348,11 @@ Optimizer | sgd | sgd | sgd

Configuration used for each model.

-Model | variable_update | local_parameter_device
------------ | ---------------------- | ----------------------
-InceptionV3 | distributed_replicated | n/a
-ResNet-50 | distributed_replicated | n/a
-ResNet-152 | distributed_replicated | n/a
+Model | variable_update | local_parameter_device | cross_replica_sync
+----------- | ---------------------- | ---------------------- | ------------------
+InceptionV3 | distributed_replicated | n/a | True
+ResNet-50 | distributed_replicated | n/a | True
+ResNet-152 | distributed_replicated | n/a | True

To simplify server setup, EC2 instances (p2.8xlarge) running worker servers also
ran parameter servers. Equal numbers of parameter servers and work servers were

@@ -371,11 +376,11 @@ used with the following exceptions:

GPUs | InceptionV3 | ResNet-50 | ResNet-152
---- | ----------- | --------- | ----------
-1 | 29.7 | 55.0 | 19.8
-8 | 229 | 410 | 150
-16 | 459 | 825 | 300
-32 | 902 | 1468 | 575
-64 | 1783 | 3051 | 1004
+1 | 29.7 | 52.4 | 19.4
+8 | 229 | 378 | 146
+16 | 459 | 751 | 291
+32 | 902 | 1388 | 565
+64 | 1783 | 2744 | 981

### Other Results

@@ -387,23 +392,23 @@ GPUs | InceptionV3 | ResNet-50 | ResNet-152

GPUs | InceptionV3 (batch size 32) | ResNet-50 (batch size 32)
---- | --------------------------- | -------------------------
-1 | 29.2 | 53.0
-8 | 219 | 363
-16 | 427 | 719
-32 | 820 | 1265
-64 | 1608 | 2623
+1 | 29.2 | 48.4
+8 | 219 | 333
+16 | 427 | 667
+32 | 820 | 1180
+64 | 1608 | 2315


## Methodology

-This [script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
+This
+[script](https://github.com/tensorflow/benchmarks/tree/master/scripts/tf_cnn_benchmarks)
was run on the various platforms to generate the above results.
@{$performance_models$High-Performance Models} details techniques in the script
along with examples of how to execute the script.

In order to create results that are as repeatable as possible, each test was run
5 times and then the times were averaged together. GPUs are run in their default
state on the given platform. For NVIDIA® Tesla® K80 this means leaving on [GPU
Boost](https://devblogs.nvidia.com/parallelforall/increase-performance-gpu-boost-k80-autoboost/).
For each test, 10 warmup steps are done and then the next 100 steps are
averaged.
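The warmup-then-average procedure described above is straightforward to reproduce. A minimal sketch, not the benchmark script itself; `train_step` stands in for whatever op the harness runs:

```python
import time

def benchmark(sess, train_step, warmup_steps=10, measured_steps=100):
  """Runs warmup iterations, then reports the mean time of the measured ones."""
  for _ in range(warmup_steps):
    sess.run(train_step)
  start = time.time()
  for _ in range(measured_steps):
    sess.run(train_step)
  return (time.time() - start) / measured_steps  # seconds per step
```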
@@ -370,9 +370,8 @@ def create_bottleneck_file(bottleneck_path, image_lists, label_name, index,
    tf.logging.fatal('File does not exist %s', image_path)
  image_data = gfile.FastGFile(image_path, 'rb').read()
  try:
-    bottleneck_values = run_bottleneck_on_image(sess, image_data,
-                                                jpeg_data_tensor,
-                                                bottleneck_tensor)
+    bottleneck_values = run_bottleneck_on_image(
+        sess, image_data, jpeg_data_tensor, bottleneck_tensor)
  except:
    raise RuntimeError('Error during processing file %s' % image_path)
@@ -5583,6 +5583,74 @@ func ReadFile(scope *Scope, filename tf.Output) (contents tf.Output) {
	return op.Output(0)
}

+// Store the input tensor in the state of the current session.
+//
+// Arguments:
+// value: The tensor to be stored.
+//
+// Returns The handle for the tensor stored in the session state, represented
+// as a ResourceHandle object.
+func GetSessionHandleV2(scope *Scope, value tf.Output) (handle tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "GetSessionHandleV2",
+		Input: []tf.Input{
+			value,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Adjust the hue of one or more images.
+//
+// `images` is a tensor of at least 3 dimensions. The last dimension is
+// interpretted as channels, and must be three.
+//
+// The input image is considered in the RGB colorspace. Conceptually, the RGB
+// colors are first mapped into HSV. A delta is then applied all the hue values,
+// and then remapped back to RGB colorspace.
+//
+// Arguments:
+// images: Images to adjust. At least 3-D.
+// delta: A float delta to add to the hue.
+//
+// Returns The hue-adjusted image or images.
+func AdjustHue(scope *Scope, images tf.Output, delta tf.Output) (output tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "AdjustHue",
+		Input: []tf.Input{
+			images, delta,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
+// Restore a Reader to its initial clean state.
+//
+// Arguments:
+// reader_handle: Handle to a Reader.
+//
+// Returns the created operation.
+func ReaderResetV2(scope *Scope, reader_handle tf.Output) (o *tf.Operation) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "ReaderResetV2",
+		Input: []tf.Input{
+			reader_handle,
+		},
+	}
+	return scope.AddOperation(opspec)
+}
+
// Computes softmax cross entropy cost and gradients to backpropagate.
//
// Unlike `SoftmaxCrossEntropyWithLogits`, this operation does not accept

@@ -19039,6 +19107,27 @@ func Igamma(scope *Scope, a tf.Output, x tf.Output) (z tf.Output) {
	return op.Output(0)
}

+// Computes arctangent of `y/x` element-wise, respecting signs of the arguments.
+//
+// This is the angle \( \theta \in [-\pi, \pi] \) such that
+// \[ x = r \cos(\theta) \]
+// and
+// \[ y = r \sin(\theta) \]
+// where \(r = \sqrt(x^2 + y^2) \).
+func Atan2(scope *Scope, y tf.Output, x tf.Output) (z tf.Output) {
+	if scope.Err() != nil {
+		return
+	}
+	opspec := tf.OpSpec{
+		Type: "Atan2",
+		Input: []tf.Input{
+			y, x,
+		},
+	}
+	op := scope.AddOperation(opspec)
+	return op.Output(0)
+}
+
// Compute the regularized incomplete beta integral \\(I_x(a, b)\\).
//
// The regularized incomplete beta integral is defined as:
@@ -21627,71 +21716,3 @@ func SoftmaxCrossEntropyWithLogits(scope *Scope, features tf.Output, labels tf.Output) (loss tf.Output, backprop tf.Output) {
	op := scope.AddOperation(opspec)
	return op.Output(0), op.Output(1)
}
-(The 68 lines removed here are the same GetSessionHandleV2, AdjustHue, and ReaderResetV2 wrappers shown verbatim in the @@ -5583 hunk above; the generated functions appear to have been moved within the file, not changed.)
@@ -1 +0,0 @@
-#include "unsupported/Eigen/CXX11/ThreadPool"
@@ -436,13 +436,13 @@ class EstimatorTrainTest(test.TestCase):
        model_dir=model_dir1,
        model_fn=model_fn_global_step_incrementer)
    est1.train(dummy_input_fn, steps=5)

    # We have to clear the cache before we can rename the directory,
    # otherwise open file handles will prevent the delete on Windows.
    writer_cache.FileWriterCache.clear()
    model_dir2 = os.path.join(tmpdir, 'model_dir2')
    os.renames(model_dir1, model_dir2)

    est2 = estimator.Estimator(
        model_dir=model_dir2,
        model_fn=model_fn_global_step_incrementer)
@@ -129,6 +129,7 @@ from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
+from tensorflow.python.ops import check_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops

@@ -656,6 +657,44 @@ def categorical_column_with_vocabulary_list(
      default_value=default_value)


+def categorical_column_with_identity(key, num_buckets, default_value=None):
+  """A `_CategoricalColumn` that returns identity values.
+
+  Use this when your inputs are integers in the range `[0, num_buckets)`. Values
+  outside this range will result in `default_value` if specified, otherwise it
+  will fail.
+
+  Inputs can be either `Tensor` or `SparseTensor`.
+  ```
+
+  Args:
+    key: A unique string identifying the input feature. It is used as the
+      column name and the dictionary key for feature parsing configs, feature
+      `Tensor` objects, and feature columns.
+    num_buckets: Range of inputs and outputs is `[0, num_buckets)`.
+    default_value: If `None`, this column's graph operations will fail for
+      out-of-range inputs. Otherwise, this value must be in the range
+      `[0, num_buckets)`, and will replace inputs in that range.
+
+  Returns:
+    A `_CategoricalColumn` that returns identity values.
+
+  Raises:
+    ValueError: if `num_buckets` is less than one.
+    ValueError: if `default_value` is not in range `[0, num_buckets)`.
+  """
+  if num_buckets < 1:
+    raise ValueError(
+        'num_buckets {} < 1, column_name {}'.format(num_buckets, key))
+  if (default_value is not None) and (
+      (default_value < 0) or (default_value >= num_buckets)):
+    raise ValueError(
+        'default_value {} not in range [0, {}), column_name {}'.format(
+            default_value, num_buckets, key))
+  return _IdentityCategoricalColumn(
+      key=key, num_buckets=num_buckets, default_value=default_value)
+
+
class _FeatureColumn(object):
  """Represents a feature column abstraction.
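A short usage sketch for the new column, modeled on the `make_linear_model` test added further below. The import path and the `'video_id'` key are illustrative assumptions, not part of this change:

```python
import tensorflow as tf
# Assumed internal module path, matching the test file's convention.
from tensorflow.python.feature_column import feature_column as fc

# Integer IDs already in [0, num_buckets): each value is its own category,
# so no vocabulary lookup is needed; out-of-range values map to default_value.
video_id = fc.categorical_column_with_identity(
    key='video_id', num_buckets=1000, default_value=0)

features = {'video_id': tf.constant([[3], [42]], dtype=tf.int64)}
predictions = fc.make_linear_model(features, [video_id])
```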
@@ -1384,6 +1423,69 @@ class _VocabularyListCategoricalColumn(
    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)


+class _IdentityCategoricalColumn(
+    _CategoricalColumn,
+    collections.namedtuple('_IdentityCategoricalColumn', (
+        'key', 'num_buckets', 'default_value'
+    ))):
+
+  """See `categorical_column_with_identity`."""
+
+  @property
+  def name(self):
+    return self.key
+
+  @property
+  def _parse_example_config(self):
+    return {self.key: parsing_ops.VarLenFeature(dtypes.int64)}
+
+  def _transform_feature(self, inputs):
+    input_tensor = _to_sparse_input(inputs.get(self.key))
+
+    if not input_tensor.dtype.is_integer:
+      raise ValueError(
+          'Invalid input, not integer. key: {} dtype: {}'.format(
+              self.key, input_tensor.dtype))
+
+    values = math_ops.to_int64(input_tensor.values, name='values')
+    num_buckets = math_ops.to_int64(self.num_buckets, name='num_buckets')
+    zero = math_ops.to_int64(0, name='zero')
+    if self.default_value is None:
+      # Fail if values are out-of-range.
+      assert_less = check_ops.assert_less(
+          values, num_buckets, data=(values, num_buckets),
+          name='assert_less_than_num_buckets')
+      assert_greater = check_ops.assert_greater_equal(
+          values, zero, data=(values,),
+          name='assert_greater_or_equal_0')
+      with ops.control_dependencies((assert_less, assert_greater)):
+        values = array_ops.identity(values)
+    else:
+      # Assign default for out-of-range values.
+      values = array_ops.where(
+          math_ops.logical_or(
+              values < zero, values >= num_buckets, name='out_of_range'),
+          array_ops.fill(
+              dims=array_ops.shape(values),
+              value=math_ops.to_int64(self.default_value),
+              name='default_values'),
+          values)
+
+    return sparse_tensor_lib.SparseTensor(
+        indices=input_tensor.indices,
+        values=values,
+        dense_shape=input_tensor.dense_shape)
+
+  @property
+  def _num_buckets(self):
+    """Returns number of buckets in this sparse feature."""
+    return self.num_buckets
+
+  def _get_sparse_tensors(
+      self, inputs, weight_collections=None, trainable=None):
+    return _CategoricalColumn.IdWeightPair(inputs.get(self), None)
+
+
# TODO(zakaria): Move this to embedding_ops and make it public.
def _safe_embedding_lookup_sparse(embedding_weights,
                                  sparse_ids,
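The `_transform_feature` branch above either asserts on out-of-range IDs or rewrites them to `default_value`. A small numpy sketch of that second branch, for intuition only; the real implementation uses the TF ops shown above:

```python
import numpy as np

def identity_transform(values, num_buckets, default_value=None):
  """Mirrors the out-of-range handling of _IdentityCategoricalColumn."""
  values = np.asarray(values, dtype=np.int64)
  out_of_range = (values < 0) | (values >= num_buckets)
  if default_value is None:
    if out_of_range.any():
      raise ValueError('values outside [0, num_buckets)')
    return values
  return np.where(out_of_range, np.int64(default_value), values)

# Matches the expectation in the default-value test added below.
print(identity_transform([1, -1, 99], num_buckets=4, default_value=3))  # [1 3 3]
```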
@@ -31,6 +31,7 @@ from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
+from tensorflow.python.ops import array_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import parsing_ops
from tensorflow.python.ops import variable_scope
@ -1828,5 +1829,198 @@ class VocabularyListCategoricalColumnTest(test.TestCase):
|
|||||||
self.assertAllClose(((3.,), (1.,)), predictions.eval())
|
self.assertAllClose(((3.,), (1.,)), predictions.eval())
|
||||||
|
|
||||||
|
|
||||||
|
class IdentityCategoricalColumnTest(test.TestCase):
|
||||||
|
|
||||||
|
def test_constructor(self):
|
||||||
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
self.assertEqual('aaa', column.name)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self.assertEqual(3, column._num_buckets)
|
||||||
|
self.assertEqual({
|
||||||
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
|
}, column._parse_example_config)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_deep_copy(self):
|
||||||
|
"""Tests deepcopy of categorical_column_with_hash_bucket."""
|
||||||
|
original = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
for column in (original, copy.deepcopy(original)):
|
||||||
|
self.assertEqual('aaa', column.name)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
self.assertEqual(3, column._num_buckets)
|
||||||
|
self.assertEqual({
|
||||||
|
'aaa': parsing_ops.VarLenFeature(dtypes.int64)
|
||||||
|
}, column._parse_example_config)
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_invalid_num_buckets_zero(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'num_buckets 0 < 1'):
|
||||||
|
fc.categorical_column_with_identity(key='aaa', num_buckets=0)
|
||||||
|
|
||||||
|
def test_invalid_num_buckets_negative(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'num_buckets -1 < 1'):
|
||||||
|
fc.categorical_column_with_identity(key='aaa', num_buckets=-1)
|
||||||
|
|
||||||
|
def test_invalid_default_value_too_small(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'default_value -1 not in range'):
|
||||||
|
fc.categorical_column_with_identity(
|
||||||
|
key='aaa', num_buckets=3, default_value=-1)
|
||||||
|
|
||||||
|
def test_invalid_default_value_too_big(self):
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'default_value 3 not in range'):
|
||||||
|
fc.categorical_column_with_identity(
|
||||||
|
key='aaa', num_buckets=3, default_value=3)
|
||||||
|
|
||||||
|
def test_invalid_input_dtype(self):
|
||||||
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=('omar', 'stringer', 'marlo'),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
with self.assertRaisesRegexp(ValueError, 'Invalid input, not integer'):
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
column._get_sparse_tensors(fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
|
||||||
|
def test_get_sparse_tensors(self):
|
||||||
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=(0, 1, 0),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=inputs.indices,
|
||||||
|
values=np.array((0, 1, 0), dtype=np.int64),
|
||||||
|
dense_shape=inputs.dense_shape),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_dense_input(self):
|
||||||
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(fc._LazyBuilder({
|
||||||
|
'aaa': ((0, -1), (1, 0))
|
||||||
|
}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=np.array((0, 1, 0), dtype=np.int64),
|
||||||
|
dense_shape=(2, 2)),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_with_inputs_too_small(self):
|
||||||
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=(1, -1, 0),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors.OpError, 'assert_greater_or_equal_0'):
|
||||||
|
id_weight_pair.id_tensor.eval()
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_with_inputs_too_big(self):
|
||||||
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=(1, 99, 0),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
with self.assertRaisesRegexp(
|
||||||
|
errors.OpError, 'assert_less_than_num_buckets'):
|
||||||
|
id_weight_pair.id_tensor.eval()
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_with_default_value(self):
|
||||||
|
column = fc.categorical_column_with_identity(
|
||||||
|
key='aaa', num_buckets=4, default_value=3)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=(1, -1, 99),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=inputs.indices,
|
||||||
|
values=np.array((1, 3, 3), dtype=np.int64),
|
||||||
|
dense_shape=inputs.dense_shape),
|
||||||
|
id_weight_pair.id_tensor.eval())
|
||||||
|
|
||||||
|
def test_get_sparse_tensors_with_default_value_and_placeholder_inputs(self):
|
||||||
|
column = fc.categorical_column_with_identity(
|
||||||
|
key='aaa', num_buckets=4, default_value=3)
|
||||||
|
input_indices = array_ops.placeholder(dtype=dtypes.int64)
|
||||||
|
input_values = array_ops.placeholder(dtype=dtypes.int32)
|
||||||
|
input_shape = array_ops.placeholder(dtype=dtypes.int64)
|
||||||
|
inputs = sparse_tensor.SparseTensorValue(
|
||||||
|
indices=input_indices,
|
||||||
|
values=input_values,
|
||||||
|
dense_shape=input_shape)
|
||||||
|
# pylint: disable=protected-access
|
||||||
|
id_weight_pair = column._get_sparse_tensors(
|
||||||
|
fc._LazyBuilder({'aaa': inputs}))
|
||||||
|
# pylint: enable=protected-access
|
||||||
|
self.assertIsNone(id_weight_pair.weight_tensor)
|
||||||
|
with _initialized_session():
|
||||||
|
_assert_sparse_tensor_value(
|
||||||
|
self,
|
||||||
|
sparse_tensor.SparseTensorValue(
|
||||||
|
indices=np.array(((0, 0), (1, 0), (1, 1)), dtype=np.int64),
|
||||||
|
values=np.array((1, 3, 3), dtype=np.int64),
|
||||||
|
dense_shape=np.array((2, 2), dtype=np.int64)),
|
||||||
|
id_weight_pair.id_tensor.eval(feed_dict={
|
||||||
|
input_indices: ((0, 0), (1, 0), (1, 1)),
|
||||||
|
input_values: (1, -1, 99),
|
||||||
|
input_shape: (2, 2),
|
||||||
|
}))
|
||||||
|
|
||||||
|
def test_make_linear_model(self):
|
||||||
|
column = fc.categorical_column_with_identity(key='aaa', num_buckets=3)
|
||||||
|
self.assertEqual(3, column._num_buckets)
|
||||||
|
with ops.Graph().as_default():
|
||||||
|
predictions = fc.make_linear_model({
|
||||||
|
column.name: sparse_tensor.SparseTensorValue(
|
||||||
|
indices=((0, 0), (1, 0), (1, 1)),
|
||||||
|
values=(0, 2, 1),
|
||||||
|
dense_shape=(2, 2))
|
||||||
|
}, (column,))
|
||||||
|
bias = get_linear_model_bias()
|
||||||
|
weight_var = get_linear_model_column_var(column)
|
||||||
|
with _initialized_session():
|
||||||
|
self.assertAllClose((0.,), bias.eval())
|
||||||
|
self.assertAllClose(((0.,), (0.,), (0.,)), weight_var.eval())
|
||||||
|
self.assertAllClose(((0.,), (0.,)), predictions.eval())
|
||||||
|
weight_var.assign(((1.,), (2.,), (3.,))).eval()
|
||||||
|
# weight_var[0] = 1
|
||||||
|
# weight_var[2] + weight_var[1] = 3+2 = 5
|
||||||
|
self.assertAllClose(((1.,), (5.,)), predictions.eval())
|
||||||
|
|
||||||
|
|
||||||
if __name__ == '__main__':
|
if __name__ == '__main__':
|
||||||
test.main()
|
test.main()
|
||||||
|
@ -113,18 +113,19 @@ def _add_op_node(op, func, input_dict):
|
|||||||
node_def = func.node_def[-1]
|
node_def = func.node_def[-1]
|
||||||
for i in range(len(node_def.input)):
|
for i in range(len(node_def.input)):
|
||||||
if not node_def.input[i].startswith("^"):
|
if not node_def.input[i].startswith("^"):
|
||||||
assert node_def.input[i] in input_dict, (
|
assert node_def.input[i] in input_dict, ("%s missing from %s" %
|
||||||
"%s missing from %s" % (node_def.input[i], input_dict.items()))
|
(node_def.input[i],
|
||||||
|
input_dict.items()))
|
||||||
node_def.input[i] = input_dict[node_def.input[i]]
|
node_def.input[i] = input_dict[node_def.input[i]]
|
||||||
|
|
||||||
|
|
||||||
def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
def _graph_to_function_def(graph, operations, inputs, outputs, out_names=None):
|
||||||
"""Returns `graph` as a `FunctionDef` protocol buffer.
|
"""Returns `graph` as a `FunctionDef` protocol buffer.
|
||||||
|
|
||||||
This method creates a [`FunctionDef`](
|
This method creates a [`FunctionDef`](
|
||||||
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
|
https://www.tensorflow.org/code/tensorflow/core/framework/function.proto)
|
||||||
protocol buffer that contains all the ops present in the graph. The
|
protocol buffer that contains all the ops in `operations`. The
|
||||||
graph effectively becomes the body of the function.
|
operations become the body of the function.
|
||||||
|
|
||||||
The arguments `inputs` and `outputs` will be listed as the inputs
|
The arguments `inputs` and `outputs` will be listed as the inputs
|
||||||
and outputs tensors of the function. They must be lists of
|
and outputs tensors of the function. They must be lists of
|
||||||
@ -132,6 +133,8 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
|||||||
|
|
||||||
Args:
|
Args:
|
||||||
graph: Graph.
|
graph: Graph.
|
||||||
|
operations: the operations to put in the function. Must be a subset of
|
||||||
|
the operations in the graph.
|
||||||
inputs: List of tensors. Inputs to the function.
|
inputs: List of tensors. Inputs to the function.
|
||||||
outputs: List of tensors. Outputs of the function.
|
outputs: List of tensors. Outputs of the function.
|
||||||
out_names: Optional list of string names for the outputs.
|
out_names: Optional list of string names for the outputs.
|
||||||
@ -145,12 +148,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
|||||||
func = function_pb2.FunctionDef()
|
func = function_pb2.FunctionDef()
|
||||||
func.signature.name = "_"
|
func.signature.name = "_"
|
||||||
used_names = set()
|
used_names = set()
|
||||||
func.signature.input_arg.extend([_tensor_to_argdef(i, used_names=used_names)
|
func.signature.input_arg.extend(
|
||||||
for i in inputs])
|
[_tensor_to_argdef(i, used_names=used_names) for i in inputs])
|
||||||
if out_names is None:
|
if out_names is None:
|
||||||
used_names = set()
|
used_names = set()
|
||||||
func.signature.output_arg.extend([
|
func.signature.output_arg.extend(
|
||||||
_tensor_to_argdef(o, used_names=used_names) for o in outputs])
|
[_tensor_to_argdef(o, used_names=used_names) for o in outputs])
|
||||||
elif len(outputs) != len(out_names):
|
elif len(outputs) != len(out_names):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Length of out_names (%d) does not match number of outputs (%d): %s" %
|
"Length of out_names (%d) does not match number of outputs (%d): %s" %
|
||||||
@ -159,12 +162,12 @@ def _graph_to_function_def(graph, inputs, outputs, out_names=None):
|
|||||||
raise ValueError(
|
raise ValueError(
|
||||||
"Must not have duplicates in out_names: %s" % ", ".join(out_names))
|
"Must not have duplicates in out_names: %s" % ", ".join(out_names))
|
||||||
else:
|
else:
|
||||||
func.signature.output_arg.extend([
|
func.signature.output_arg.extend(
|
||||||
_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
|
[_tensor_to_argdef(o, name=n) for o, n in zip(outputs, out_names)])
|
||||||
func_arg_placeholders = set([i.name for i in inputs])
|
func_arg_placeholders = set([i.name for i in inputs])
|
||||||
   input_dict = _create_input_dict(graph, func_arg_placeholders)
 
-  for op in graph.get_operations():
+  for op in operations:
     if _is_in_placeholders(op, func_arg_placeholders):
       continue
     _add_op_node(op, func, input_dict)
@@ -295,17 +298,18 @@ class _FuncGraph(ops.Graph):
     self.extra_args = []
     self.extra_vars = []
 
-  def getvar(self,
-             getter,
-             name,
-             shape=None,
-             dtype=None,
-             initializer=None,
-             reuse=None,
-             trainable=True,
-             collections=None,  # pylint: disable=redefined-outer-name
-             use_resource=None,
-             **kwargs):
+  def getvar(
+      self,
+      getter,
+      name,
+      shape=None,
+      dtype=None,
+      initializer=None,
+      reuse=None,
+      trainable=True,
+      collections=None,  # pylint: disable=redefined-outer-name
+      use_resource=None,
+      **kwargs):
     """A custom variable getter."""
     # Here, we switch the default graph to the outer graph and ask the
     # variable scope in which the function is defined to give us the
@@ -538,20 +542,23 @@ class _DefinedFunction(object):
 
     # Build the FunctionDef
     self._definition = _graph_to_function_def(
-        temp_graph, inputs, outputs, out_names=self._out_names)
+        temp_graph,
+        temp_graph.get_operations(),
+        inputs,
+        outputs,
+        out_names=self._out_names)
 
     # Extra kwargs are treated as attrs on the function def.
     sig_pre_func_name = self._func_name or _get_func_name(self._func)
-    kwargs_attr = _parse_kwargs_as_attrs(
-        sig_pre_func_name, **self._extra_kwargs)
+    kwargs_attr = _parse_kwargs_as_attrs(sig_pre_func_name,
+                                         **self._extra_kwargs)
     for k in kwargs_attr:
       self._definition.attr[k].CopyFrom(kwargs_attr[k])
 
     # Hash the definition and its dependencies.
     self._hash_str = self._create_hash_str(
         self._definition.signature.input_arg,
-        self._definition.signature.output_arg,
-        self._definition.node_def)
+        self._definition.signature.output_arg, self._definition.node_def)
 
     # Finally, we decide the function name to use. If not specified,
     # make up something which is almost certainly unique (but deterministic).
@@ -658,8 +665,8 @@ def _from_definition(fdef, grad_func=None):
   # have access to such a callable here).
   func = None
   argnames = [arg.name for arg in fdef.signature.input_arg]
-  input_types = tuple(dtypes.as_dtype(arg.type)
-                      for arg in fdef.signature.input_arg)
+  input_types = tuple(
+      dtypes.as_dtype(arg.type) for arg in fdef.signature.input_arg)
   func_name = fdef.signature.name
   # Note: FunctionDefs do not include python gradient functions, so if the
   # original _DefinedFunction included one it will not be reflected here.
@@ -675,8 +682,7 @@ def _from_definition(fdef, grad_func=None):
   result._extra_inputs = []
   result._hash_str = result._create_hash_str(
       result._definition.signature.input_arg,
-      result._definition.signature.output_arg,
-      result._definition.node_def)
+      result._definition.signature.output_arg, result._definition.node_def)
   # pylint: enable=protected-access
   return result
 
@@ -696,7 +702,8 @@ def _from_library(lib):
   Raises:
     ValueError: `lib` is invalid
   """
-  if not lib.function and not lib.gradient: return []
+  if not lib.function and not lib.gradient:
+    return []
 
   # function name -> FunctionDef proto
   funcs = {fdef.signature.name: fdef for fdef in lib.function}
@@ -720,8 +727,9 @@ def _from_library(lib):
     grad_to_funcs[gdef.gradient_func].append(gdef.function_name)
 
   # Start with functions without gradients
-  ready = [fdef for fdef in lib.function
-           if func_to_grad[fdef.signature.name] is None]
+  ready = [
+      fdef for fdef in lib.function if func_to_grad[fdef.signature.name] is None
+  ]
   if not ready:
     raise ValueError("FunctionDefLibrary contains cyclic gradient functions!\n"
                      + str(lib))
@@ -733,7 +741,8 @@ def _from_library(lib):
     name = fdef.signature.name
 
     grad = initialized.get(func_to_grad[name])
-    if func_to_grad[name]: assert grad
+    if func_to_grad[name]:
+      assert grad
     defined_func = _from_definition(fdef, grad_func=grad)
     initialized[name] = defined_func
 
@@ -835,10 +844,15 @@ class _OverloadedFunction(object):
       name = self._func_name
       if name is not None:
         name = "_".join([name, key])
-      defined = _DefinedFunction(self._func, self._argnames, input_types, name,
-                                 None, self._python_grad_func,
-                                 out_names=self._out_names,
-                                 **self._extra_kwargs)
+      defined = _DefinedFunction(
+          self._func,
+          self._argnames,
+          input_types,
+          name,
+          None,
+          self._python_grad_func,
+          out_names=self._out_names,
+          **self._extra_kwargs)
       _ = defined.name  # Fully instantiate the function definition.
       if self._grad_func:
         # If _grad_func is given, it is another
@@ -849,8 +863,8 @@ class _OverloadedFunction(object):
             for _ in defined.definition.signature.output_arg
         ]
         # pylint: disable=protected-access
-        defined._grad_func = self._grad_func.instantiate(input_types +
-                                                         output_types)
+        defined._grad_func = self._grad_func.instantiate(
+            input_types + output_types)
         # pylint: enable=protected-access
       self._overload[key] = defined
     return defined
@@ -981,22 +995,36 @@ class Defun(object):
         raise ValueError(
            "The function has fewer arguments than the number of specified "
            "input types.")
-      return _DefinedFunction(func, argnames, self._input_types,
-                              self._func_name, self._grad_func,
-                              self._python_grad_func,
-                              out_names=self._out_names, **self._extra_kwargs)
+      return _DefinedFunction(
+          func,
+          argnames,
+          self._input_types,
+          self._func_name,
+          self._grad_func,
+          self._python_grad_func,
+          out_names=self._out_names,
+          **self._extra_kwargs)
 
     # 'func' expects no arguments and input types is an empty list.
     if min_args == 0 and max_args == 0:
-      return _DefinedFunction(func, [], [], self._func_name, self._grad_func,
-                              self._python_grad_func,
-                              out_names=self._out_names, **self._extra_kwargs)
+      return _DefinedFunction(
+          func, [], [],
+          self._func_name,
+          self._grad_func,
+          self._python_grad_func,
+          out_names=self._out_names,
+          **self._extra_kwargs)
 
     # Input types are unknown. It's an overloaded function and hence
     # its definition needs to be deferred until it's called.
-    return _OverloadedFunction(func, argnames, self._func_name, self._grad_func,
-                               self._python_grad_func,
-                               out_names=self._out_names, **self._extra_kwargs)
+    return _OverloadedFunction(
+        func,
+        argnames,
+        self._func_name,
+        self._grad_func,
+        self._python_grad_func,
+        out_names=self._out_names,
+        **self._extra_kwargs)
 
 
 class Declare(object):
@@ -1039,8 +1067,10 @@ class Declare(object):
       names = [n for n, t in args]
       if len(names) != len(set(names)):
        raise ValueError("Expected names to all be unique: %s" % str(names))
-      return [op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
-              for n, t in args]
+      return [
+          op_def_pb2.OpDef.ArgDef(type=t.as_datatype_enum, name=n)
+          for n, t in args
+      ]
 
     self._sig.input_arg.extend(_to_argdef_list(inputs))
     self._sig.output_arg.extend(_to_argdef_list(outputs))
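For context, a minimal usage sketch of the `Defun` decorator whose call sites are reflowed above. This is not part of the commit; it assumes TF 1.x graph mode and the constant values are purely illustrative.

# Hedged sketch: defining and calling a graph function with function.Defun (TF 1.x).
import tensorflow as tf
from tensorflow.python.framework import function

@function.Defun(tf.float32, tf.float32)
def scaled_add(x, y):
  # The decorated body is traced once into a FunctionDef and reused at call sites.
  return 0.5 * x + y

a = tf.constant([1.0, 2.0])
b = tf.constant([3.0, 4.0])
c = scaled_add(a, b)

with tf.Session() as sess:
  print(sess.run(c))  # [3.5, 5.0]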
@@ -1106,16 +1106,18 @@ class BinaryOpTest(test.TestCase):
 
   def testAtan2SpecialValues(self):
     x1l, x2l = zip((+0.0, +0.0), (+0.0, -0.0), (-0.0, +0.0), (-0.0, -0.0),
-                   (1.2345, float('inf')), (1.2345, -float('inf')),
-                   (-4.321, float('inf')), (-4.125, -float('inf')),
-                   (float('inf'), float('inf')), (float('inf'), -float('inf')),
-                   (-float('inf'), float('inf')), (-float('inf'), -float('inf')))
+                   (1.2345, float("inf")), (1.2345, -float("inf")),
+                   (-4.321, float("inf")), (-4.125, -float("inf")),
+                   (float("inf"), float("inf")), (float("inf"), -float("inf")),
+                   (-float("inf"), float("inf")), (-float("inf"),
+                                                   -float("inf")))
     for dtype in np.float32, np.float64:
       x1 = np.array(x1l).astype(dtype)
       x2 = np.array(x2l).astype(dtype)
       self._compareCpu(x1, x2, np.arctan2, math_ops.atan2)
       self._compareGpu(x1, x2, np.arctan2, math_ops.atan2)
 
 
 class ComparisonOpTest(test.TestCase):
 
   def _compareScalar(self, func, x, y, dtype):
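For reference, the signed-zero and infinity conventions that the test above checks against `np.arctan2` can be seen directly with NumPy; this standalone snippet is illustrative only and not part of the change.

# IEEE-754 special-value behavior of arctan2, as exercised by testAtan2SpecialValues.
import numpy as np

print(np.arctan2(+0.0, -0.0))        # pi
print(np.arctan2(-0.0, -0.0))        # -pi
print(np.arctan2(1.2345, np.inf))    # 0.0
print(np.arctan2(np.inf, np.inf))    # pi/4
print(np.arctan2(-np.inf, np.inf))   # -pi/4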
@@ -19,58 +19,65 @@ from __future__ import print_function
 
 import numpy as np
 
-from tensorflow import Tensor
-from tensorflow import register_tensor_conversion_function
 from tensorflow.python.framework import ops
-from tensorflow.python.ops import math_ops
 from tensorflow.python.platform import test as test_lib
 
 
 class TensorPriorityTest(test_lib.TestCase):
 
   def testSupportedRhsWithoutDelegation(self):
 
     class NumpyArraySubclass(np.ndarray):
       pass
-    supported_rhs_without_delegation = (
-        3,
-        3.0,
-        [1.0, 2.0],
-        np.array([1.0, 2.0]),
-        NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0])),
-        ops.convert_to_tensor([[1.0, 2.0]]))
+
+    supported_rhs_without_delegation = (3, 3.0, [1.0, 2.0], np.array(
+        [1.0, 2.0]), NumpyArraySubclass(
+            shape=(1, 2), buffer=np.array([1.0, 2.0])),
+                                        ops.convert_to_tensor([[1.0, 2.0]]))
     for rhs in supported_rhs_without_delegation:
       tensor = ops.convert_to_tensor([[10.0, 20.0]])
       res = tensor + rhs
-      self.assertIsInstance(res, Tensor)
+      self.assertIsInstance(res, ops.Tensor)
 
   def testUnsupportedRhsWithoutDelegation(self):
 
     class WithoutReverseAdd(object):
       pass
 
     tensor = ops.convert_to_tensor([[10.0, 20.0]])
     rhs = WithoutReverseAdd()
     with self.assertRaisesWithPredicateMatch(
         TypeError, lambda e: "Expected float" in str(e)):
-      res = tensor + rhs
+      # pylint: disable=pointless-statement
+      tensor + rhs
 
   def testUnsupportedRhsWithDelegation(self):
 
     class WithReverseAdd(object):
 
       def __radd__(self, lhs):
         return "Works!"
 
     tensor = ops.convert_to_tensor([[10.0, 20.0]])
     rhs = WithReverseAdd()
     res = tensor + rhs
     self.assertEqual(res, "Works!")
 
   def testFullDelegationControlUsingRegistry(self):
 
     class NumpyArraySubclass(np.ndarray):
 
       def __radd__(self, lhs):
         return "Works!"
 
     def raise_to_delegate(value, dtype=None, name=None, as_ref=False):
+      del value, dtype, name, as_ref  # Unused.
       raise TypeError
-    register_tensor_conversion_function(NumpyArraySubclass, raise_to_delegate,
-                                        priority=0)
+
+    ops.register_tensor_conversion_function(
+        NumpyArraySubclass, raise_to_delegate, priority=0)
     tensor = ops.convert_to_tensor([[10.0, 20.0]])
-    rhs = NumpyArraySubclass(shape=(1,2), buffer=np.array([1.0, 2.0]))
+    rhs = NumpyArraySubclass(shape=(1, 2), buffer=np.array([1.0, 2.0]))
     res = tensor + rhs
     self.assertEqual(res, "Works!")
 
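The registry mechanism exercised by these tests can also be used from the public TF 1.x API. The sketch below is not from the commit; the `Celsius` class and its conversion function are made up for illustration.

# Hedged sketch: teaching convert_to_tensor about a custom type via the conversion registry.
import numpy as np
import tensorflow as tf

class Celsius(object):
  def __init__(self, degrees):
    self.degrees = degrees

def _celsius_to_tensor(value, dtype=None, name=None, as_ref=False):
  del name, as_ref  # Unused in this sketch.
  return tf.constant(value.degrees, dtype=dtype or tf.float32)

tf.register_tensor_conversion_function(Celsius, _celsius_to_tensor, priority=100)

x = tf.convert_to_tensor(Celsius(21.0)) + 1.0  # Celsius values now mix with Tensors.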
@@ -1109,10 +1109,10 @@ class Conv2DTranspose(Conv2D):
     # Infer the static output shape:
     out_shape = inputs.get_shape().as_list()
     out_shape[c_axis] = self.filters
-    out_shape[h_axis] = utils.get_deconv_dim(
-        out_shape[h_axis], stride_h, kernel_h, self.padding)
-    out_shape[w_axis] = utils.get_deconv_dim(
-        out_shape[w_axis], stride_w, kernel_w, self.padding)
+    out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
+                                             kernel_h, self.padding)
+    out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
+                                             kernel_w, self.padding)
     outputs.set_shape(out_shape)
 
     if self.bias:
@@ -1240,7 +1240,8 @@ class Conv3DTranspose(Conv3D):
     name: A string, the name of the layer.
   """
 
-  def __init__(self, filters,
+  def __init__(self,
+               filters,
                kernel_size,
                strides=(1, 1, 1),
                padding='valid',
@@ -1269,12 +1270,13 @@ class Conv3DTranspose(Conv3D):
         bias_regularizer=bias_regularizer,
         activity_regularizer=activity_regularizer,
         trainable=trainable,
-        name=name, **kwargs)
+        name=name,
+        **kwargs)
 
   def build(self, input_shape):
     if len(input_shape) != 5:
-      raise ValueError('Inputs should have rank 5, ' +
-                       'received input shape:', str(input_shape))
+      raise ValueError('Inputs should have rank 5, received input shape:',
+                       str(input_shape))
     if self.data_format == 'channels_first':
       channel_axis = 1
     else:
@@ -1285,22 +1287,23 @@ class Conv3DTranspose(Conv3D):
     input_dim = input_shape[channel_axis]
     kernel_shape = self.kernel_size + (self.filters, input_dim)
 
-    self.kernel = self.add_variable('kernel',
-                                    shape=kernel_shape,
-                                    initializer=self.kernel_initializer,
-                                    regularizer=self.kernel_regularizer,
-                                    trainable=True,
-                                    dtype=self.dtype)
+    self.kernel = self.add_variable(
+        'kernel',
+        shape=kernel_shape,
+        initializer=self.kernel_initializer,
+        regularizer=self.kernel_regularizer,
+        trainable=True,
+        dtype=self.dtype)
     if self.use_bias:
-      self.bias = self.add_variable('bias',
-                                    shape=(self.filters,),
-                                    initializer=self.bias_initializer,
-                                    regularizer=self.bias_regularizer,
-                                    trainable=True,
-                                    dtype=self.dtype)
+      self.bias = self.add_variable(
+          'bias',
+          shape=(self.filters,),
+          initializer=self.bias_initializer,
+          regularizer=self.bias_regularizer,
+          trainable=True,
+          dtype=self.dtype)
     else:
       self.bias = None
-    self.built = True
 
   def call(self, inputs):
     inputs_shape = array_ops.shape(inputs)
@@ -1343,26 +1346,26 @@ class Conv3DTranspose(Conv3D):
     # Infer the static output shape:
     out_shape = inputs.get_shape().as_list()
     out_shape[c_axis] = self.filters
-    out_shape[d_axis] = utils.get_deconv_dim(
-        out_shape[d_axis], stride_d, kernel_d, self.padding)
-    out_shape[h_axis] = utils.get_deconv_dim(
-        out_shape[h_axis], stride_h, kernel_h, self.padding)
-    out_shape[w_axis] = utils.get_deconv_dim(
-        out_shape[w_axis], stride_w, kernel_w, self.padding)
+    out_shape[d_axis] = utils.get_deconv_dim(out_shape[d_axis], stride_d,
+                                             kernel_d, self.padding)
+    out_shape[h_axis] = utils.get_deconv_dim(out_shape[h_axis], stride_h,
+                                             kernel_h, self.padding)
+    out_shape[w_axis] = utils.get_deconv_dim(out_shape[w_axis], stride_w,
+                                             kernel_w, self.padding)
     outputs.set_shape(out_shape)
 
     if self.bias:
       outputs_shape = outputs.shape.as_list()
       if self.data_format == 'channels_first':
-        outputs_4d = array_ops.reshape(outputs,
-                                       [outputs_shape[0], outputs_shape[1],
-                                        outputs_shape[2] * outputs_shape[3],
-                                        outputs_shape[4]])
+        outputs_4d = array_ops.reshape(outputs, [
+            outputs_shape[0], outputs_shape[1],
+            outputs_shape[2] * outputs_shape[3], outputs_shape[4]
+        ])
       else:
-        outputs_4d = array_ops.reshape(outputs,
-                                       [outputs_shape[0],
-                                        outputs_shape[1] * outputs_shape[2],
-                                        outputs_shape[3], outputs_shape[4]])
+        outputs_4d = array_ops.reshape(outputs, [
+            outputs_shape[0], outputs_shape[1] * outputs_shape[2],
+            outputs_shape[3], outputs_shape[4]
+        ])
       outputs_4d = nn.bias_add(
           outputs_4d,
           self.bias,
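The static shape inference reformatted above is observable from the public TF 1.x layers API. A rough sketch, not part of the commit, with arbitrary input sizes:

# Hedged sketch: Conv3DTranspose output-shape inference (TF 1.x, NDHWC layout).
import tensorflow as tf

volumes = tf.placeholder(tf.float32, [5, 4, 6, 8, 32])
outputs = tf.layers.conv3d_transpose(
    volumes, filters=4, kernel_size=[3, 3, 3], strides=2, padding='same')
print(outputs.get_shape().as_list())  # [5, 8, 12, 16, 4]: spatial dims doubled, channels=filters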
@@ -715,8 +715,8 @@ class Conv3DTransposeTest(test.TestCase):
     layer = conv_layers.Conv3DTranspose(
         32, volumes.get_shape()[1:4], padding='same')
     output = layer.apply(volumes)
-    self.assertListEqual(output.get_shape().as_list(), [5, depth, height,
-                                                        width, 32])
+    self.assertListEqual(output.get_shape().as_list(),
+                         [5, depth, height, width, 32])
 
   def testCreateConv3DTransposeWithStrides(self):
     depth, height, width = 4, 6, 8
@@ -729,8 +729,7 @@ class Conv3DTransposeTest(test.TestCase):
                          [5, depth * 2, height * 2, width * 2, 4])
 
     # Test strides integer.
-    layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2,
-                                        padding='same')
+    layer = conv_layers.Conv3DTranspose(4, [3, 3, 3], strides=2, padding='same')
     output = layer.apply(volumes)
     self.assertListEqual(output.get_shape().as_list(),
                          [5, depth * 2, height * 2, width * 2, 4])
@@ -779,14 +778,14 @@ class Conv3DTransposeTest(test.TestCase):
     volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
     conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
     self.assertEqual(len(variables.trainable_variables()), 2)
-    conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
+    conv_layers.conv3d_transpose(
+        volumes, 4, [3, 3, 3], name='deconv1', reuse=True)
     self.assertEqual(len(variables.trainable_variables()), 2)
 
   def testFunctionalConv3DTransposeReuseFromScope(self):
     with variable_scope.variable_scope('scope'):
       depth, height, width = 5, 7, 9
-      volumes = random_ops.random_uniform((5, depth, height, width, 32),
-                                          seed=1)
+      volumes = random_ops.random_uniform((5, depth, height, width, 32), seed=1)
       conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
       self.assertEqual(len(variables.trainable_variables()), 2)
       with variable_scope.variable_scope('scope', reuse=True):
@@ -798,8 +797,8 @@ class Conv3DTransposeTest(test.TestCase):
     with variable_scope.variable_scope(
         'scope', initializer=init_ops.ones_initializer()):
       depth, height, width = 5, 7, 9
-      volumes = random_ops.random_uniform((5, depth, height, width, 32),
-                                          seed=1)
+      volumes = random_ops.random_uniform(
+          (5, depth, height, width, 32), seed=1)
       conv_layers.conv3d_transpose(volumes, 4, [3, 3, 3], name='deconv1')
       weights = variables.trainable_variables()
       # Check the names of weights in order.
@@ -205,7 +205,8 @@ def ctc_greedy_decoder(inputs, sequence_length, merge_repeated=True):
       `decoded.shape`: Shape vector, size `(2)`.
         The shape values are: `[batch_size, max_decoded_length]`
     neg_sum_logits: A `float` matrix `(batch_size x 1)` containing, for the
-        sequence found, the negative of the sum of the greatest logit at each timeframe.
+        sequence found, the negative of the sum of the greatest logit at each
+        timeframe.
   """
   outputs = gen_ctc_ops._ctc_greedy_decoder(
       inputs, sequence_length, merge_repeated=merge_repeated)
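A brief usage sketch of the decoder whose docstring is rewrapped above, assuming the public TF 1.x wrapper `tf.nn.ctc_greedy_decoder`; shapes and values are illustrative only.

# Hedged sketch: greedy CTC decoding.
import tensorflow as tf

# logits: [max_time, batch_size, num_classes]; the last class is the CTC blank label.
logits = tf.random_normal([50, 2, 28])
seq_len = tf.constant([50, 40], dtype=tf.int32)

decoded, neg_sum_logits = tf.nn.ctc_greedy_decoder(logits, seq_len)
# decoded[0] is a SparseTensor of shape [batch_size, max_decoded_length];
# neg_sum_logits has shape [batch_size, 1], as described in the docstring.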
@@ -39,6 +39,7 @@ from tensorflow.python.framework import constant_op
 from tensorflow.python.framework import dtypes
 from tensorflow.python.ops import array_ops
 from tensorflow.python.ops import linalg_ops
+from tensorflow.python.ops import math_ops
 from tensorflow.python.ops import random_ops
 from tensorflow.python.ops import math_ops
 
@@ -964,8 +964,12 @@ def atrous_conv2d(value, filters, rate, padding, name=None):
     ValueError: If input/output depth does not match `filters`' shape, or if
       padding is other than `'VALID'` or `'SAME'`.
   """
-  return convolution(input=value, filter=filters, padding=padding,
-                     dilation_rate=np.broadcast_to(rate, (2, )), name=name)
+  return convolution(
+      input=value,
+      filter=filters,
+      padding=padding,
+      dilation_rate=np.broadcast_to(rate, (2,)),
+      name=name)
 
 
 def conv2d_transpose(value,
@@ -1231,8 +1235,8 @@ def conv3d_transpose(value,
     axis = 1 if data_format == "NCDHW" else 4
     if not value.get_shape()[axis].is_compatible_with(filter.get_shape()[4]):
       raise ValueError("input channels does not match filter's input channels, "
-                       "{} != {}".format(value.get_shape()[axis], filter.get_shape(
-                       )[4]))
+                       "{} != {}".format(value.get_shape()[axis],
+                                         filter.get_shape()[4]))
 
     output_shape_ = ops.convert_to_tensor(output_shape, name="output_shape")
     if not output_shape_.get_shape().is_compatible_with(tensor_shape.vector(5)):
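For orientation, a minimal sketch of the atrous (dilated) convolution wrapper reformatted above, using the public TF 1.x API; tensor shapes are made up for illustration and not part of the commit.

# Hedged sketch: dilated 2-D convolution via tf.nn.atrous_conv2d.
import tensorflow as tf

value = tf.random_normal([1, 64, 64, 3])    # NHWC input
filters = tf.random_normal([3, 3, 3, 16])   # HWIO filter
out = tf.nn.atrous_conv2d(value, filters, rate=2, padding='SAME')
print(out.get_shape().as_list())            # [1, 64, 64, 16] with 'SAME' padding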
@@ -195,46 +195,47 @@ def load(sess, tags, export_dir, **saver_kwargs):
   Raises:
     RuntimeError: MetaGraphDef associated with the tags cannot be found.
   """
-  # Build the SavedModel protocol buffer and find the requested meta graph def.
-  saved_model = _parse_saved_model(export_dir)
-  found_match = False
-  for meta_graph_def in saved_model.meta_graphs:
-    if set(meta_graph_def.meta_info_def.tags) == set(tags):
-      meta_graph_def_to_load = meta_graph_def
-      found_match = True
-      break
+  with sess.graph.as_default():
+    # Build the SavedModel protocol buffer and find requested meta graph def.
+    saved_model = _parse_saved_model(export_dir)
+    found_match = False
+    for meta_graph_def in saved_model.meta_graphs:
+      if set(meta_graph_def.meta_info_def.tags) == set(tags):
+        meta_graph_def_to_load = meta_graph_def
+        found_match = True
+        break
 
-  if not found_match:
-    raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
-        "[]") + " could not be found in SavedModel")
+    if not found_match:
+      raise RuntimeError("MetaGraphDef associated with tags " + str(tags).strip(
+          "[]") + " could not be found in SavedModel")
 
-  # Build a saver by importing the meta graph def to load.
-  saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
+    # Build a saver by importing the meta graph def to load.
+    saver = tf_saver.import_meta_graph(meta_graph_def_to_load, **saver_kwargs)
 
-  if saver:
-    # Build the checkpoint path where the variables are located.
-    variables_path = os.path.join(
-        compat.as_bytes(export_dir),
-        compat.as_bytes(constants.VARIABLES_DIRECTORY),
-        compat.as_bytes(constants.VARIABLES_FILENAME))
+    if saver:
+      # Build the checkpoint path where the variables are located.
+      variables_path = os.path.join(
+          compat.as_bytes(export_dir),
+          compat.as_bytes(constants.VARIABLES_DIRECTORY),
+          compat.as_bytes(constants.VARIABLES_FILENAME))
 
-    # Restore the variables using the built saver in the provided session.
-    saver.restore(sess, variables_path)
-  else:
-    tf_logging.info("The specified SavedModel has no variables; no "
-                    "checkpoints were restored.")
+      # Restore the variables using the built saver in the provided session.
+      saver.restore(sess, variables_path)
+    else:
+      tf_logging.info("The specified SavedModel has no variables; no "
+                      "checkpoints were restored.")
 
-  # Get asset tensors, if any.
-  asset_tensors_dictionary = _get_asset_tensors(export_dir,
-                                                meta_graph_def_to_load)
+    # Get asset tensors, if any.
+    asset_tensors_dictionary = _get_asset_tensors(export_dir,
+                                                  meta_graph_def_to_load)
 
-  main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
-  if main_op_tensor is not None:
-    sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
-  else:
-    legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
-    if legacy_init_op_tensor is not None:
-      sess.run(fetches=[legacy_init_op_tensor],
-               feed_dict=asset_tensors_dictionary)
+    main_op_tensor = _get_main_op_tensor(meta_graph_def_to_load)
+    if main_op_tensor is not None:
+      sess.run(fetches=[main_op_tensor], feed_dict=asset_tensors_dictionary)
+    else:
+      legacy_init_op_tensor = _get_legacy_init_op_tensor(meta_graph_def_to_load)
+      if legacy_init_op_tensor is not None:
+        sess.run(
+            fetches=[legacy_init_op_tensor], feed_dict=asset_tensors_dictionary)
 
-  return meta_graph_def_to_load
+    return meta_graph_def_to_load
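A hedged sketch of how callers consume `loader.load()` after this rework (TF 1.x SavedModel API); the export directory path is a placeholder and the snippet is not part of the commit.

# Loading a SavedModel into a caller-supplied session.
import tensorflow as tf
from tensorflow.python.saved_model import loader
from tensorflow.python.saved_model import tag_constants

export_dir = "/tmp/my_saved_model"  # assumed to already contain a SavedModel
with tf.Session(graph=tf.Graph()) as sess:
  # With the change above, the MetaGraphDef is imported into sess.graph even if
  # the caller has not made that graph the default graph first.
  loader.load(sess, [tag_constants.SERVING], export_dir)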
@@ -151,6 +151,27 @@ class SavedModelTest(test.TestCase):
                                    constants.SAVED_MODEL_FILENAME_PBTXT):
       loader.load(sess, ["foo"], export_dir)
 
+  def testVerifySessionGraphUsage(self):
+    export_dir = os.path.join(test.get_temp_dir(),
+                              "test_verify_session_graph_usage")
+    builder = saved_model_builder.SavedModelBuilder(export_dir)
+
+    with self.test_session(graph=ops.Graph()) as sess:
+      self._init_and_validate_variable(sess, "v", 42)
+      builder.add_meta_graph_and_variables(sess, [tag_constants.TRAINING])
+
+    # Save the SavedModel to disk.
+    builder.save()
+
+    # Build a session and supply it to the load operation.
+    sess = session.Session(graph=ops.Graph())
+    loader.load(sess, [tag_constants.TRAINING], export_dir)
+
+    # Check the variable within the scope of the session and its graph.
+    with sess:
+      self.assertEqual(
+          42, ops.get_collection(ops.GraphKeys.GLOBAL_VARIABLES)[0].eval())
+
   def testSequence(self):
     export_dir = os.path.join(test.get_temp_dir(), "test_sequence")
     builder = saved_model_builder.SavedModelBuilder(export_dir)
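The same export/reload round trip can be written with the public TF 1.x API, mirroring the new test above; the path and variable value below are illustrative assumptions, not part of the commit.

# Hedged sketch: build a SavedModel, then load it into a fresh session/graph.
import tensorflow as tf

export_dir = "/tmp/savedmodel_roundtrip"  # placeholder path
with tf.Session(graph=tf.Graph()) as sess:
  v = tf.Variable(42, name="v")
  sess.run(tf.global_variables_initializer())
  builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
  builder.add_meta_graph_and_variables(sess, [tf.saved_model.tag_constants.TRAINING])
  builder.save()

sess = tf.Session(graph=tf.Graph())
tf.saved_model.loader.load(sess, [tf.saved_model.tag_constants.TRAINING], export_dir)
with sess:
  print(tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)[0].eval())  # 42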
@@ -12,33 +12,39 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 # ================================
+"""Imports a protobuf model as a graph in Tensorboard."""
 
 from __future__ import absolute_import
 from __future__ import division
 from __future__ import print_function
 
-import tensorflow as tf
+from tensorflow.core.framework import graph_pb2
+from tensorflow.python.client import session
+from tensorflow.python.framework import importer
+from tensorflow.python.framework import ops
+from tensorflow.python.platform import gfile
+from tensorflow.python.summary import summary
 
 
 def import_to_tensorboard(model_dir, log_dir):
   """View an imported protobuf model (`.pb` file) as a graph in Tensorboard.
 
   Args:
     model_dir: The location of the protobuf (`pb`) model to visualize
     log_dir: The location for the Tensorboard log to begin visualisation from.
 
   Usage:
     Call this function with your model location and desired log directory.
     Launch Tensorboard by pointing it to the log directory.
     View your imported `.pb` model as a graph.
   """
-  with tf.Session(graph=tf.Graph()) as sess:
-    with tf.gfile.FastGFile(model_dir, 'rb') as f:
-      graph_def = tf.GraphDef()
+  with session.Session(graph=ops.Graph()) as sess:
+    with gfile.FastGFile(model_dir, "rb") as f:
+      graph_def = graph_pb2.GraphDef()
       graph_def.ParseFromString(f.read())
-      g_in = tf.import_graph_def(graph_def)
+      importer.import_graph_def(graph_def)
 
-    pb_visual_writer = tf.summary.FileWriter(log_dir)
+    pb_visual_writer = summary.FileWriter(log_dir)
     pb_visual_writer.add_graph(sess.graph)
     print("Model Imported. Visualize by running: "
           "> tensorboard --logdir={}".format(log_dir))
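A hypothetical invocation of the helper refactored above, following the Usage section of its docstring. The module path is assumed (the tool historically lives under tensorflow/python/tools) and both file paths are placeholders.

# Hedged sketch: rendering a frozen .pb graph in TensorBoard.
from tensorflow.python.tools.import_pb_to_tensorboard import import_to_tensorboard

import_to_tensorboard(model_dir="/tmp/frozen_model.pb", log_dir="/tmp/tb_logs")
# Then launch: tensorboard --logdir=/tmp/tb_logs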
@@ -504,7 +504,14 @@ def run(args):
 
   Args:
     args: A namespace parsed from command line.
+
+  Raises:
+    AttributeError: An error when neither --inputs nor --input_exprs is passed
+      to run command.
   """
+  if not args.inputs and not args.input_exprs:
+    raise AttributeError(
+        'At least one of --inputs and --input_exprs must be required')
   tensor_key_feed_dict = load_inputs_from_input_arg_string(
       args.inputs, args.input_exprs)
   run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
@@ -629,8 +636,6 @@ def create_parser():
 def main():
   parser = create_parser()
   args = parser.parse_args()
-  if not args.inputs and not args.input_exprs:
-    args.error('At least one of --inputs and --input_exprs is required')
   args.func(args)
 
 
@@ -409,6 +409,16 @@ Method name is: tensorflow/serving/predict"""
     with self.assertRaises(RuntimeError):
       saved_model_cli.run(args)
 
+  def testRunCommandInputNotGivenError(self):
+    self.parser = saved_model_cli.create_parser()
+    base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
+    args = self.parser.parse_args([
+        'run', '--dir', base_path, '--tag_set', 'serve', '--signature_def',
+        'serving_default'
+    ])
+    with self.assertRaises(AttributeError):
+      saved_model_cli.run(args)
+
   def testRunCommandWithDebuggerEnabled(self):
     self.parser = saved_model_cli.create_parser()
     base_path = test.test_src_dir_path(SAVED_MODEL_PATH)
@@ -210,9 +210,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
       else:
         var_name = ",".join([v.name for v in var])
       _set_variable_or_list_initializer(var, ckpt_file, tensor_name_in_ckpt)
-      logging.info("Initialize variable %s from checkpoint %s with %s" % (
-          var_name, ckpt_dir_or_file, tensor_name_in_ckpt
-      ))
+      logging.info("Initialize variable %s from checkpoint %s with %s",
+                   var_name, ckpt_dir_or_file, tensor_name_in_ckpt)
     else:
       scopes = ""
       # TODO(vihanjain): Support list of 'current_var_or_name' here.
@@ -250,9 +249,8 @@ def init_from_checkpoint(ckpt_dir_or_file, assignment_map):
         if var is None:
           var = _collect_partitioned_variable(var_name, store_vars)
         _set_variable_or_list_initializer(var, ckpt_file, full_tensor_name)
-        logging.info("Initialize variable %s from checkpoint %s with %s" % (
-            var_name, ckpt_dir_or_file, full_tensor_name
-        ))
+        logging.info("Initialize variable %s from checkpoint %s with %s",
+                     var_name, ckpt_dir_or_file, full_tensor_name)
 
 
 def _get_checkpoint_filename(ckpt_dir_or_file):
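These hunks (and the saver.py ones below) switch from eager %-interpolation to passing format arguments to the logger. A small standard-library illustration of the difference, assuming the TF logger forwards to Python's logging as usual:

# Why lazy logger formatting is preferred: the message is only built if it is emitted.
import logging

logging.basicConfig(level=logging.WARNING)
expensive = {"lots": "of data"}

# Eager: the string is always constructed, even though INFO is filtered out here.
logging.info("state: %s" % expensive)

# Lazy: arguments are stored and interpolated only when the record is actually logged.
logging.info("state: %s", expensive)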
@@ -935,11 +935,11 @@ def get_checkpoint_state(checkpoint_dir, latest_filename=None):
         ckpt.all_model_checkpoint_paths[i] = os.path.join(checkpoint_dir, p)
   except errors.OpError as e:
     # It's ok if the file cannot be read
-    logging.warning("%s: %s" % (type(e).__name__, e))
+    logging.warning("%s: %s", type(e).__name__, e)
     logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
     return None
   except text_format.ParseError as e:
-    logging.warning("%s: %s" % (type(e).__name__, e))
+    logging.warning("%s: %s", type(e).__name__, e)
     logging.warning("%s: Checkpoint ignored", coord_checkpoint_filename)
     return None
   finally:
@@ -230,13 +230,15 @@ class TensorboardServerTest(test.TestCase):
   def testScalars(self):
     """Test the format of /data/scalars."""
     data = self._getJson('/data/scalars?run=run1&tag=simple_values')
-    self.assertEqual(len(data),self._SCALAR_COUNT)
+    self.assertEqual(len(data), self._SCALAR_COUNT)
 
   def testScalarsCsv(self):
     """Test the csv format of /data/scalars."""
-    data = self._get('/data/scalars?run=run1&tag=simple_values&format=csv').read()
+    data = self._get(
+        '/data/scalars?run=run1&tag=simple_values&format=csv').read()
     line_count = data.count('\n')
-    self.assertEqual(line_count,self._SCALAR_COUNT + 1) # include 1 more line for header
+    self.assertEqual(line_count,
+                     self._SCALAR_COUNT + 1)  # include 1 more line for header
 
   def testHistograms(self):
     """Test the format of /data/histograms."""
63
tensorflow/tensorboard/components/tf_audio_dashboard/BUILD
Normal file
63
tensorflow/tensorboard/components/tf_audio_dashboard/BUILD
Normal file
@ -0,0 +1,63 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
webfiles(
|
||||||
|
name = "tf_audio_dashboard",
|
||||||
|
srcs = [
|
||||||
|
"tf-audio-dashboard.html",
|
||||||
|
"tf-audio-grid.html",
|
||||||
|
"tf-audio-loader.html",
|
||||||
|
],
|
||||||
|
path = "/tf-audio-dashboard",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components/tf_backend",
|
||||||
|
"//tensorflow/tensorboard/components/tf_dashboard_common",
|
||||||
|
"//tensorflow/tensorboard/components/tf_imports:lodash",
|
||||||
|
"@org_polymer",
|
||||||
|
"@org_polymer_paper_icon_button",
|
||||||
|
"@org_polymer_paper_slider",
|
||||||
|
"@org_polymer_paper_spinner",
|
||||||
|
"@org_polymer_paper_styles",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# MARKED FOR DELETION
|
||||||
|
|
||||||
|
tensorboard_webcomponent_library(
|
||||||
|
name = "legacy",
|
||||||
|
srcs = [
|
||||||
|
"tf-audio-dashboard.html",
|
||||||
|
"tf-audio-grid.html",
|
||||||
|
"tf-audio-loader.html",
|
||||||
|
],
|
||||||
|
destdir = "tf-audio-dashboard",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components:tf_imports",
|
||||||
|
"//tensorflow/tensorboard/components/tf_backend:legacy",
|
||||||
|
"//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
|
||||||
|
"//third_party/javascript/polymer/v1/paper-icon-button:lib",
|
||||||
|
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||||
|
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# This is needed: components/BUILD seeks a legacy_ts rule in this package.
|
||||||
|
tensorboard_ts_library(
|
||||||
|
name = "legacy_ts",
|
||||||
|
srcs = [],
|
||||||
|
deps_mgmt = "off",
|
||||||
|
runtime = "nodejs",
|
||||||
|
deps = ["//tensorflow/tensorboard/components:common_deps"],
|
||||||
|
)
|
@ -0,0 +1,26 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
# bazel run //third_party/tensorflow/tensorboard/components/tf_audio_dashboard/demo
|
||||||
|
webfiles(
|
||||||
|
name = "demo",
|
||||||
|
srcs = ["index.html"],
|
||||||
|
path = "/tf-audio-dashboard/demo",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components/tf_audio_dashboard",
|
||||||
|
"//tensorflow/tensorboard/components/tf_audio_dashboard/demo/data",
|
||||||
|
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||||
|
"@org_polymer_iron_demo_helpers",
|
||||||
|
"@org_polymer_paper_styles",
|
||||||
|
"@org_polymer_webcomponentsjs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
@ -0,0 +1,17 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
webfiles(
|
||||||
|
name = "data",
|
||||||
|
srcs = glob(["*"]),
|
||||||
|
path = "/tf-audio-dashboard/demo/data",
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
@ -107,6 +107,8 @@ future for loading older clips.
|
|||||||
</template>
|
</template>
|
||||||
</template>
|
</template>
|
||||||
<script>
|
<script>
|
||||||
|
"use strict";
|
||||||
|
|
||||||
Polymer({
|
Polymer({
|
||||||
is: "tf-audio-loader",
|
is: "tf-audio-loader",
|
||||||
properties: {
|
properties: {
|
||||||
|
81
tensorflow/tensorboard/components/tf_backend/BUILD
Normal file
81
tensorflow/tensorboard/components/tf_backend/BUILD
Normal file
@ -0,0 +1,81 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
# TODO(dandelion): Add webfiles support for the test code.
|
||||||
|
|
||||||
|
webfiles(
|
||||||
|
name = "tf_backend",
|
||||||
|
srcs = [
|
||||||
|
"tf-backend.html",
|
||||||
|
":ts",
|
||||||
|
],
|
||||||
|
path = "/tf-backend",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||||
|
"//tensorflow/tensorboard/components/tf_imports:lodash",
|
||||||
|
"//tensorflow/tensorboard/components/vz_sorting",
|
||||||
|
"@org_polymer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_typescript_genrule(
|
||||||
|
name = "ts",
|
||||||
|
srcs = [
|
||||||
|
"backend.ts",
|
||||||
|
"behavior.ts",
|
||||||
|
"requestManager.ts",
|
||||||
|
"router.ts",
|
||||||
|
"urlPathHelpers.ts",
|
||||||
|
],
|
||||||
|
typings = [
|
||||||
|
"@org_definitelytyped//:d3.d.ts",
|
||||||
|
"@org_definitelytyped//:lodash.d.ts",
|
||||||
|
"//tensorflow/tensorboard/components/vz_sorting:ts_typings",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# MARKED FOR DELETION
|
||||||
|
|
||||||
|
tensorboard_webcomponent_library(
|
||||||
|
name = "legacy",
|
||||||
|
srcs = [
|
||||||
|
"tf-backend.html",
|
||||||
|
":legacy_ts",
|
||||||
|
],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
destdir = "tf-backend",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components:tf_imports",
|
||||||
|
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_ts_library(
|
||||||
|
name = "legacy_ts",
|
||||||
|
srcs = [
|
||||||
|
"backend.ts",
|
||||||
|
"behavior.ts",
|
||||||
|
"requestManager.ts",
|
||||||
|
"router.ts",
|
||||||
|
"urlPathHelpers.ts",
|
||||||
|
],
|
||||||
|
deps_mgmt = "off",
|
||||||
|
runtime = "nodejs",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components:common_deps",
|
||||||
|
"//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
|
||||||
|
],
|
||||||
|
)
|
45
tensorflow/tensorboard/components/tf_backend_d3v4/BUILD
Normal file
45
tensorflow/tensorboard/components/tf_backend_d3v4/BUILD
Normal file
@ -0,0 +1,45 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
tensorboard_ts_library(
|
||||||
|
name = "ts",
|
||||||
|
srcs = [
|
||||||
|
"backend.ts",
|
||||||
|
"behavior.ts",
|
||||||
|
"requestManager.ts",
|
||||||
|
"router.ts",
|
||||||
|
"urlPathHelpers.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
|
||||||
|
"//third_party/javascript/node_modules/typescript:es2015.promise",
|
||||||
|
"//third_party/javascript/typings/chai",
|
||||||
|
"//third_party/javascript/typings/d3_v4:bundle",
|
||||||
|
"//third_party/javascript/typings/lodash",
|
||||||
|
"//third_party/javascript/typings/mocha",
|
||||||
|
"//third_party/javascript/typings/polymer:polymer_without_externs",
|
||||||
|
"//third_party/javascript/typings/sinon",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
# TODO(dandelion): Add runners for these tests
|
||||||
|
tensorboard_ts_library(
|
||||||
|
name = "tests",
|
||||||
|
srcs = [
|
||||||
|
"backendTests.ts",
|
||||||
|
"behaviorTests.ts",
|
||||||
|
"requestManagerTests.ts",
|
||||||
|
],
|
||||||
|
deps = [
|
||||||
|
":ts",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
65
tensorflow/tensorboard/components/tf_color_scale/BUILD
Normal file
65
tensorflow/tensorboard/components/tf_color_scale/BUILD
Normal file
@ -0,0 +1,65 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
|
||||||
|
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
# TODO(dandelion): Add webfiles support for the test code.
|
||||||
|
|
||||||
|
webfiles(
|
||||||
|
name = "tf_color_scale",
|
||||||
|
srcs = [
|
||||||
|
"tf-color-scale.html",
|
||||||
|
":ts",
|
||||||
|
],
|
||||||
|
path = "/tf-color-scale",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||||
|
"@org_polymer",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_typescript_genrule(
|
||||||
|
name = "ts",
|
||||||
|
srcs = [
|
||||||
|
"colorScale.ts",
|
||||||
|
"palettes.ts",
|
||||||
|
],
|
||||||
|
typings = ["@org_definitelytyped//:d3.d.ts"],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
||||||
|
|
||||||
|
################################################################################
|
||||||
|
# MARKED FOR DELETION
|
||||||
|
|
||||||
|
tensorboard_webcomponent_library(
|
||||||
|
name = "legacy",
|
||||||
|
srcs = [
|
||||||
|
"tf-color-scale.html",
|
||||||
|
":legacy_ts",
|
||||||
|
],
|
||||||
|
destdir = "tf-color-scale",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components:tf_imports",
|
||||||
|
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_ts_library(
|
||||||
|
name = "legacy_ts",
|
||||||
|
srcs = [
|
||||||
|
"colorScale.ts",
|
||||||
|
"palettes.ts",
|
||||||
|
],
|
||||||
|
deps_mgmt = "off",
|
||||||
|
runtime = "nodejs",
|
||||||
|
deps = ["//tensorflow/tensorboard/components:common_deps"],
|
||||||
|
)
|
26
tensorflow/tensorboard/components/tf_color_scale/demo/BUILD
Normal file
26
tensorflow/tensorboard/components/tf_color_scale/demo/BUILD
Normal file
@ -0,0 +1,26 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
# bazel run //third_party/tensorflow/tensorboard/components/tf_color_scale/demo
|
||||||
|
webfiles(
|
||||||
|
name = "demo",
|
||||||
|
srcs = ["index.html"],
|
||||||
|
path = "/tf-color-scale/demo",
|
||||||
|
deps = [
|
||||||
|
"//tensorflow/tensorboard/components/tf_color_scale",
|
||||||
|
"//tensorflow/tensorboard/components/tf_imports:d3",
|
||||||
|
"@org_polymer_iron_demo_helpers",
|
||||||
|
"@org_polymer_paper_button",
|
||||||
|
"@org_polymer_paper_styles",
|
||||||
|
"@org_polymer_webcomponentsjs",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
72
tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD
Normal file
72
tensorflow/tensorboard/components/tf_color_scale_d3v4/BUILD
Normal file
@ -0,0 +1,72 @@
|
|||||||
|
package(default_visibility = ["//tensorflow:internal"])
|
||||||
|
|
||||||
|
load(
|
||||||
|
"//tensorflow/tensorboard:defs.bzl",
|
||||||
|
"tensorboard_ts_development_sources",
|
||||||
|
"tensorboard_ts_devserver",
|
||||||
|
"tensorboard_ts_library",
|
||||||
|
"tensorboard_webcomponent_library",
|
||||||
|
)
|
||||||
|
|
||||||
|
licenses(["notice"]) # Apache 2.0
|
||||||
|
|
||||||
|
# TODO(dandelion): Add runner for the test code.
|
||||||
|
|
||||||
|
tensorboard_webcomponent_library(
|
||||||
|
name = "tf_color_scale",
|
||||||
|
srcs = ["tf-color-scale.html"],
|
||||||
|
ts_lib_deps = [":ts"],
|
||||||
|
deps = [
|
||||||
|
"//third_party/javascript/polymer/v1/polymer:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_ts_library(
|
||||||
|
name = "ts",
|
||||||
|
srcs = [
|
||||||
|
"colorScale.ts",
|
||||||
|
"palettes.ts",
|
||||||
|
],
|
||||||
|
deps = ["//tensorflow/tensorboard/components:common_deps_d3v4"],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_ts_library(
|
||||||
|
name = "tests",
|
||||||
|
srcs = ["colorScaleTests.ts"],
|
||||||
|
deps = [
|
||||||
|
":ts",
|
||||||
|
"//tensorflow/tensorboard/components:common_deps_d3v4",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
filegroup(
|
||||||
|
name = "all_files",
|
||||||
|
srcs = glob(["**"]),
|
||||||
|
tags = ["notsan"],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_webcomponent_library(
|
||||||
|
name = "demo",
|
||||||
|
srcs = ["demo.html"],
|
||||||
|
visibility = ["//visibility:public"],
|
||||||
|
deps = [
|
||||||
|
":tf_color_scale",
|
||||||
|
"//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
|
||||||
|
"//third_party/javascript/polymer/v1/paper-button:lib",
|
||||||
|
"//third_party/javascript/polymer/v1/paper-styles:lib",
|
||||||
|
],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_ts_devserver(
|
||||||
|
name = "devserver",
|
||||||
|
manifest = ":dev_sources",
|
||||||
|
serving_path = "/demo_out/bundle.js",
|
||||||
|
static_files = [":demo"],
|
||||||
|
)
|
||||||
|
|
||||||
|
tensorboard_ts_development_sources(
|
||||||
|
name = "dev_sources",
|
||||||
|
deps = [
|
||||||
|
":ts",
|
||||||
|
],
|
||||||
|
)
|
102
tensorflow/tensorboard/components/tf_dashboard_common/BUILD
Normal file
102
tensorflow/tensorboard/components/tf_dashboard_common/BUILD
Normal file
@ -0,0 +1,102 @@
package(default_visibility = ["//tensorflow:internal"])

load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_typescript_genrule")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")

licenses(["notice"])  # Apache 2.0

webfiles(
    name = "tf_dashboard_common",
    srcs = glob(["*.html"]) + [
        ":ts",
    ],
    path = "/tf-dashboard-common",
    deps = [
        "//tensorflow/tensorboard/components/tf_imports:lodash",
        "//tensorflow/tensorboard/components/tf_imports:plottable",
        "//tensorflow/tensorboard/components/tf_storage",
        "//tensorflow/tensorboard/components/vz_sorting",
        "@org_polymer",
        "@org_polymer_iron_ajax",
        "@org_polymer_iron_collapse",
        "@org_polymer_iron_icons",
        "@org_polymer_paper_button",
        "@org_polymer_paper_checkbox",
        "@org_polymer_paper_dialog",
        "@org_polymer_paper_dropdown_menu",
        "@org_polymer_paper_icon_button",
        "@org_polymer_paper_input",
        "@org_polymer_paper_item",
        "@org_polymer_paper_menu",
        "@org_polymer_paper_slider",
        "@org_polymer_paper_spinner",
        "@org_polymer_paper_styles",
        "@org_polymer_paper_toggle_button",
    ],
)

tensorboard_typescript_genrule(
    name = "ts",
    srcs = [
        "categorizer.ts",
        "dashboard-behavior.ts",
        "reload-behavior.ts",
    ],
    typings = [
        "@org_definitelytyped//:d3.d.ts",
        "//tensorflow/tensorboard/components/vz_sorting:ts_typings",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)

################################################################################
# MARKED FOR DELETION

tensorboard_webcomponent_library(
    name = "legacy",
    srcs = glob(["*.html"]) + [":legacy_ts"],
    destdir = "tf-dashboard-common",
    deps = [
        "//tensorflow/tensorboard/components:tf_imports",
        "//tensorflow/tensorboard/components/tf_storage:legacy",
        "//tensorflow/tensorboard/components/vz_sorting:legacy",
        "//third_party/javascript/polymer/v1/iron-ajax:lib",
        "//third_party/javascript/polymer/v1/iron-collapse:lib",
        "//third_party/javascript/polymer/v1/iron-icons:lib",
        "//third_party/javascript/polymer/v1/paper-button:lib",
        "//third_party/javascript/polymer/v1/paper-checkbox:lib",
        "//third_party/javascript/polymer/v1/paper-dialog:lib",
        "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
        "//third_party/javascript/polymer/v1/paper-icon-button:lib",
        "//third_party/javascript/polymer/v1/paper-input:lib",
        "//third_party/javascript/polymer/v1/paper-item:lib",
        "//third_party/javascript/polymer/v1/paper-menu:lib",
        "//third_party/javascript/polymer/v1/paper-slider:lib",
        "//third_party/javascript/polymer/v1/paper-spinner:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
        "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
        "//third_party/javascript/polymer/v1/polymer:lib",
    ],
)

tensorboard_ts_library(
    name = "legacy_ts",
    srcs = [
        "categorizer.ts",
        "dashboard-behavior.ts",
        "reload-behavior.ts",
    ],
    deps_mgmt = "off",
    runtime = "nodejs",
    deps = [
        "//tensorflow/tensorboard/components:common_deps",
        "//tensorflow/tensorboard/components/vz_sorting:legacy_ts",
    ],
)
31  tensorflow/tensorboard/components/tf_dashboard_common/demo/BUILD  Normal file
@ -0,0 +1,31 @@
package(default_visibility = ["//tensorflow:internal"])

load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")

licenses(["notice"])  # Apache 2.0

# bazel run //third_party/tensorflow/tensorboard/components/tf_dashboard_common/demo
webfiles(
    name = "demo",
    srcs = [
        "tf-categorizer-demo.html",
        "tf-collapsable-pane-demo.html",
        "tf-multi-checkbox-demo.html",
        "tf-regex-group-demo.html",
    ],
    path = "/tf-dashboard-common/demo",
    deps = [
        "//tensorflow/tensorboard/components/tf_color_scale",
        "//tensorflow/tensorboard/components/tf_dashboard_common",
        "//tensorflow/tensorboard/components/tf_imports:d3",
        "@org_polymer_iron_flex_layout",
        "@org_polymer_paper_styles",
        "@org_polymer_webcomponentsjs",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)
@ -57,6 +57,8 @@ plugin is requred to implement two functions:
  </style>
</template>
<script>
+"use strict";
+
Polymer({
  is: "tf-chart-scaffold",
  properties: {
114  tensorflow/tensorboard/components/tf_dashboard_common_d3v4/BUILD  Normal file
@ -0,0 +1,114 @@
package(default_visibility = ["//tensorflow:internal"])

load(
    "//tensorflow/tensorboard:defs.bzl",
    "tensorboard_ts_development_sources",
    "tensorboard_ts_devserver",
    "tensorboard_ts_library",
    "tensorboard_webcomponent_library",
)

licenses(["notice"])  # Apache 2.0

tensorboard_webcomponent_library(
    name = "tf_dashboard_common",
    srcs = [
        "dashboard-style.html",
        "run-color-style.html",
        "scrollbar-style.html",
        "tensorboard-color.html",
        "tf-categorizer.html",
        "tf-collapsable-pane.html",
        "tf-dashboard.html",
        "tf-dashboard-layout.html",
        "tf-downloader.html",
        "tf-multi-checkbox.html",
        "tf-no-data-warning.html",
        "tf-option-selector.html",
        "tf-panes-helper.html",
        "tf-regex-group.html",
        "tf-run-selector.html",
        "tf-sidebar-helper.html",
    ],
    ts_lib_deps = [":ts"],
    deps = [
        "//third_party/javascript/plottable/v3:lib",
        "//third_party/javascript/polymer/v1/iron-ajax:lib",
        "//third_party/javascript/polymer/v1/iron-collapse:lib",
        "//third_party/javascript/polymer/v1/iron-icons:lib",
        "//third_party/javascript/polymer/v1/paper-button:lib",
        "//third_party/javascript/polymer/v1/paper-checkbox:lib",
        "//third_party/javascript/polymer/v1/paper-dialog:lib",
        "//third_party/javascript/polymer/v1/paper-dropdown-menu:lib",
        "//third_party/javascript/polymer/v1/paper-icon-button:lib",
        "//third_party/javascript/polymer/v1/paper-input:lib",
        "//third_party/javascript/polymer/v1/paper-item:lib",
        "//third_party/javascript/polymer/v1/paper-menu:lib",
        "//third_party/javascript/polymer/v1/paper-slider:lib",
        "//third_party/javascript/polymer/v1/paper-spinner:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
        "//third_party/javascript/polymer/v1/paper-toggle-button:lib",
        "//third_party/javascript/polymer/v1/polymer:lib",
    ],
)

tensorboard_ts_library(
    name = "ts",
    srcs = [
        "dashboard-behavior.ts",
        "reload-behavior.ts",
        "tf-categorizer.ts",
        "tf-multi-checkbox.ts",
        "tf-regex-group.ts",
    ],
    deps = [
        "//tensorflow/tensorboard/components:common_deps_d3v4",
        "//tensorflow/tensorboard/components/tf_storage_d3v4:ts",
        "//tensorflow/tensorboard/components/vz_sorting_d3v4:ts",
    ],
)

tensorboard_ts_library(
    name = "tests",
    srcs = ["tf-categorizer-tests.ts"],
    deps = [
        ":ts",
        "//tensorflow/tensorboard/components:common_deps_d3v4",
    ],
)

tensorboard_webcomponent_library(
    name = "demo",
    srcs = [
        "tf-categorizer-demo.html",
        "tf-collapsable-pane-demo.html",
        "tf-multi-checkbox-demo.html",
        "tf-regex-group-demo.html",
    ],
    deps = [
        ":tf_dashboard_common",
        "//tensorflow/tensorboard/components/tf_color_scale_d3v4:tf_color_scale",
        "//third_party/javascript/polymer/v1/iron-demo-helpers:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
    ],
)

tensorboard_ts_devserver(
    name = "devserver",
    manifest = ":dev_sources",
    serving_path = "/demo_out/bundle.js",
    static_files = [":demo"],
)

tensorboard_ts_development_sources(
    name = "dev_sources",
    deps = [
        ":ts",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)
@ -55,6 +55,8 @@ plugin is requred to implement two functions:
  </style>
</template>
<script>
+"use strict";
+
Polymer({
  is: "tf-chart-scaffold",
  properties: {
63  tensorflow/tensorboard/components/tf_distribution_dashboard/BUILD  Normal file
@ -0,0 +1,63 @@
package(default_visibility = ["//tensorflow:internal"])

load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_ts_library")
load("//tensorflow/tensorboard:defs.bzl", "tensorboard_webcomponent_library")

licenses(["notice"])  # Apache 2.0

webfiles(
    name = "tf_distribution_dashboard",
    srcs = [
        "tf-distribution-dashboard.html",
    ],
    path = "/tf-distribution-dashboard",
    deps = [
        "//tensorflow/tensorboard/components/tf_backend",
        "//tensorflow/tensorboard/components/tf_color_scale",
        "//tensorflow/tensorboard/components/tf_dashboard_common",
        "//tensorflow/tensorboard/components/tf_imports:lodash",
        "//tensorflow/tensorboard/components/vz_distribution_chart",
        "@org_polymer",
        "@org_polymer_iron_collapse",
        "@org_polymer_paper_icon_button",
        "@org_polymer_paper_styles",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)

################################################################################
# MARKED FOR DELETION

tensorboard_webcomponent_library(
    name = "legacy",
    srcs = [
        "tf-distribution-dashboard.html",
        ":legacy_ts",
    ],
    destdir = "tf-distribution-dashboard",
    deps = [
        "//tensorflow/tensorboard/components:tf_imports",
        "//tensorflow/tensorboard/components/tf_backend:legacy",
        "//tensorflow/tensorboard/components/tf_dashboard_common:legacy",
        "//tensorflow/tensorboard/components/vz_distribution_chart:legacy",
        "//third_party/javascript/polymer/v1/iron-collapse:lib",
        "//third_party/javascript/polymer/v1/paper-icon-button:lib",
        "//third_party/javascript/polymer/v1/paper-styles:lib",
        "//third_party/javascript/polymer/v1/polymer:lib",
    ],
)

# This is needed: components/BUILD seeks a legacy_ts rule in this package.
tensorboard_ts_library(
    name = "legacy_ts",
    srcs = [],
    deps_mgmt = "off",
    runtime = "nodejs",
    deps = ["//tensorflow/tensorboard/components:common_deps"],
)
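To illustrate how a webfiles target like the one above is pulled into a larger page, here is a hedged sketch of a consuming rule; the tf_my_dashboard_page target and its HTML source are hypothetical and not part of this commit.

# Hypothetical page that bundles the new distribution dashboard (not in this diff).
webfiles(
    name = "tf_my_dashboard_page",  # illustrative target name
    srcs = ["tf-my-dashboard-page.html"],  # hypothetical source file
    path = "/tf-my-dashboard-page",
    deps = [
        "//tensorflow/tensorboard/components/tf_distribution_dashboard",
        "@org_polymer",
    ],
)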
26  tensorflow/tensorboard/components/tf_distribution_dashboard/demo/BUILD  Normal file
@ -0,0 +1,26 @@
package(default_visibility = ["//tensorflow:internal"])

load("@io_bazel_rules_closure//closure:defs.bzl", "webfiles")

licenses(["notice"])  # Apache 2.0

# bazel run //third_party/tensorflow/tensorboard/components/tf_distribution_dashboard/demo
webfiles(
    name = "demo",
    srcs = ["index.html"],
    path = "/tf-distribution-dashboard/demo",
    deps = [
        "//tensorflow/tensorboard/components/tf_distribution_dashboard",
        "//tensorflow/tensorboard/components/tf_distribution_dashboard/demo/data",
        "//tensorflow/tensorboard/components/tf_imports:d3",
        "@org_polymer_iron_demo_helpers",
        "@org_polymer_paper_styles",
        "@org_polymer_webcomponentsjs",
    ],
)

filegroup(
    name = "all_files",
    srcs = glob(["**"]),
    tags = ["notsan"],
)
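The demo above depends on a sibling demo/data package whose BUILD file is not shown in this diff. Purely as an assumption about its shape, a data package of this kind could be another webfiles rule serving static JSON fixtures, along the lines of:

# Hedged sketch only; the real demo/data/BUILD is not included in this diff,
# and the glob pattern is an assumed fixture layout.
webfiles(
    name = "data",
    srcs = glob(["*.json"]),
    path = "/tf-distribution-dashboard/demo/data",
)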
Some files were not shown because too many files have changed in this diff.