Start moving scatter plot methods out of vz-projector and into the scatter plot

adapter. Add DataSet reference to Projection class. The projector adapter now listens to the the distance metric changed event, as well as creates + owns scatter plot.
Change: 139496759
This commit is contained in:
Charles Nicholson 2016-11-17 13:06:08 -08:00 committed by TensorFlower Gardener
parent 839ee165dc
commit 815fa1b32d
7 changed files with 274 additions and 183 deletions

View File

@ -415,7 +415,8 @@ export type ProjectionType = 'tsne' | 'pca' | 'custom';
export class Projection { export class Projection {
constructor( constructor(
public projectionType: ProjectionType, public projectionType: ProjectionType,
public pointAccessors: PointAccessors3D, public dimensionality: number) {} public pointAccessors: PointAccessors3D, public dimensionality: number,
public dataSet: DataSet) {}
} }
export interface ColorOption { export interface ColorOption {

View File

@ -13,15 +13,16 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
import {DataSet, DistanceFunction} from './data'; import {DistanceFunction, Projection} from './data';
import {NearestEntry} from './knn'; import {NearestEntry} from './knn';
export type HoverListener = (index: number) => void; export type HoverListener = (index: number) => void;
export type SelectionChangedListener = export type SelectionChangedListener =
(selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[]) => (selectedPointIndices: number[], neighborsOfFirstPoint: NearestEntry[]) =>
void; void;
export type ProjectionChangedListener = (dataSet: DataSet) => void; export type ProjectionChangedListener = (projection: Projection) => void;
export type DistanceMetricChangedListener =
(distanceMetric: DistanceFunction) => void;
export interface ProjectorEventContext { export interface ProjectorEventContext {
/** Register a callback to be invoked when the mouse hovers over a point. */ /** Register a callback to be invoked when the mouse hovers over a point. */
registerHoverListener(listener: HoverListener); registerHoverListener(listener: HoverListener);
@ -37,6 +38,8 @@ export interface ProjectorEventContext {
/** Registers a callback to be invoked when the projection changes. */ /** Registers a callback to be invoked when the projection changes. */
registerProjectionChangedListener(listener: ProjectionChangedListener); registerProjectionChangedListener(listener: ProjectionChangedListener);
/** Notify listeners that a reprojection occurred. */ /** Notify listeners that a reprojection occurred. */
notifyProjectionChanged(dataSet: DataSet); notifyProjectionChanged(projection: Projection);
registerDistanceMetricChangedListener(listener:
DistanceMetricChangedListener);
notifyDistanceMetricChanged(distMetric: DistanceFunction); notifyDistanceMetricChanged(distMetric: DistanceFunction);
} }

View File

@ -13,9 +13,15 @@ See the License for the specific language governing permissions and
limitations under the License. limitations under the License.
==============================================================================*/ ==============================================================================*/
import {DataSet, DistanceFunction, PointAccessors3D} from './data'; import {DataSet, DistanceFunction, PointAccessors3D, Projection, State} from './data';
import {NearestEntry} from './knn'; import {NearestEntry} from './knn';
import {ProjectorEventContext} from './projectorEventContext';
import {LabelRenderParams} from './renderContext'; import {LabelRenderParams} from './renderContext';
import {ScatterPlot} from './scatterPlot';
import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels';
import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels';
import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites';
import {ScatterPlotVisualizerTraces} from './scatterPlotVisualizerTraces';
import * as vector from './vector'; import * as vector from './vector';
const LABEL_FONT_SIZE = 10; const LABEL_FONT_SIZE = 10;
@ -69,6 +75,119 @@ const NN_COLOR_SCALE =
* to use the ScatterPlot to render the current projected data set. * to use the ScatterPlot to render the current projected data set.
*/ */
export class ProjectorScatterPlotAdapter { export class ProjectorScatterPlotAdapter {
public scatterPlot: ScatterPlot;
private scatterPlotContainer: d3.Selection<any>;
private projection: Projection;
private hoverPointIndex: number;
private selectedPointIndices: number[];
private neighborsOfFirstSelectedPoint: NearestEntry[];
private renderLabelsIn3D: boolean = false;
private legendPointColorer: (index: number) => string;
private distanceMetric: DistanceFunction;
constructor(
scatterPlotContainer: d3.Selection<any>,
projectorEventContext: ProjectorEventContext) {
this.scatterPlot =
new ScatterPlot(scatterPlotContainer, projectorEventContext);
this.scatterPlotContainer = scatterPlotContainer;
projectorEventContext.registerProjectionChangedListener(projection => {
this.projection = projection;
this.updateScatterPlotWithNewProjection(projection);
});
projectorEventContext.registerSelectionChangedListener(
(selectedPointIndices, neighbors) => {
this.selectedPointIndices = selectedPointIndices;
this.neighborsOfFirstSelectedPoint = neighbors;
this.updateScatterPlotAttributes();
this.scatterPlot.render();
});
projectorEventContext.registerHoverListener(hoverPointIndex => {
this.hoverPointIndex = hoverPointIndex;
this.updateScatterPlotAttributes();
this.scatterPlot.render();
});
projectorEventContext.registerDistanceMetricChangedListener(
distanceMetric => {
this.distanceMetric = distanceMetric;
this.updateScatterPlotAttributes();
this.scatterPlot.render();
});
this.createVisualizers(false);
}
notifyProjectionPositionsUpdated() {
this.updateScatterPlotPositions();
this.scatterPlot.render();
}
set3DLabelMode(renderLabelsIn3D: boolean) {
this.renderLabelsIn3D = renderLabelsIn3D;
this.createVisualizers(renderLabelsIn3D);
this.updateScatterPlotAttributes();
this.scatterPlot.render();
}
setLegendPointColorer(legendPointColorer: (index: number) => string) {
this.legendPointColorer = legendPointColorer;
}
resize() {
this.scatterPlot.resize();
}
populateBookmarkFromUI(state: State) {
state.cameraDef = this.scatterPlot.getCameraDef();
}
restoreUIFromBookmark(state: State) {
this.scatterPlot.setCameraParametersForNextCameraCreation(
state.cameraDef, false);
}
updateScatterPlotPositions() {
const ds = (this.projection == null) ? null : this.projection.dataSet;
const accessors =
(this.projection == null) ? null : this.projection.pointAccessors;
const newPositions = this.generatePointPositionArray(ds, accessors);
this.scatterPlot.setPointPositions(ds, newPositions);
}
updateScatterPlotAttributes() {
if (this.projection == null) {
return;
}
const dataSet = this.projection.dataSet;
const selectedSet = this.selectedPointIndices;
const hoverIndex = this.hoverPointIndex;
const neighbors = this.neighborsOfFirstSelectedPoint;
const pointColorer = this.legendPointColorer;
const pointColors = this.generatePointColorArray(
dataSet, pointColorer, this.distanceMetric, selectedSet, neighbors,
hoverIndex, this.renderLabelsIn3D, this.getSpriteImageMode());
const pointScaleFactors = this.generatePointScaleFactorArray(
dataSet, selectedSet, neighbors, hoverIndex);
const labels = this.generateVisibleLabelRenderParams(
dataSet, selectedSet, neighbors, hoverIndex);
const traceColors = this.generateLineSegmentColorMap(dataSet, pointColorer);
const traceOpacities =
this.generateLineSegmentOpacityArray(dataSet, selectedSet);
const traceWidths =
this.generateLineSegmentWidthArray(dataSet, selectedSet);
this.scatterPlot.setPointColors(pointColors);
this.scatterPlot.setPointScaleFactors(pointScaleFactors);
this.scatterPlot.setLabels(labels);
this.scatterPlot.setTraceColors(traceColors);
this.scatterPlot.setTraceOpacities(traceOpacities);
this.scatterPlot.setTraceWidths(traceWidths);
}
render() {
this.scatterPlot.render();
}
generatePointPositionArray(ds: DataSet, pointAccessors: PointAccessors3D): generatePointPositionArray(ds: DataSet, pointAccessors: PointAccessors3D):
Float32Array { Float32Array {
if (ds == null) { if (ds == null) {
@ -116,19 +235,6 @@ export class ProjectorScatterPlotAdapter {
return positions; return positions;
} }
private packRgbIntoUint8Array(
rgbArray: Uint8Array, labelIndex: number, r: number, g: number,
b: number) {
rgbArray[labelIndex * 3] = r;
rgbArray[labelIndex * 3 + 1] = g;
rgbArray[labelIndex * 3 + 2] = b;
}
private styleRgbFromHexColor(hex: number): [number, number, number] {
const c = new THREE.Color(hex);
return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0];
}
generateVisibleLabelRenderParams( generateVisibleLabelRenderParams(
ds: DataSet, selectedPointIndices: number[], ds: DataSet, selectedPointIndices: number[],
neighborsOfFirstPoint: NearestEntry[], neighborsOfFirstPoint: NearestEntry[],
@ -155,11 +261,11 @@ export class ProjectorScatterPlotAdapter {
visibleLabels[dst] = hoverPointIndex; visibleLabels[dst] = hoverPointIndex;
scale[dst] = LABEL_SCALE_LARGE; scale[dst] = LABEL_SCALE_LARGE;
opacityFlags[dst] = 0; opacityFlags[dst] = 0;
const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER); const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_HOVER);
this.packRgbIntoUint8Array( packRgbIntoUint8Array(
fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER); const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_HOVER);
this.packRgbIntoUint8Array( packRgbIntoUint8Array(
strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[1]); strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[1]);
++dst; ++dst;
} }
@ -167,15 +273,15 @@ export class ProjectorScatterPlotAdapter {
// Selected points // Selected points
{ {
const n = selectedPointIndices.length; const n = selectedPointIndices.length;
const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED); const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_SELECTED);
const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED); const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_SELECTED);
for (let i = 0; i < n; ++i) { for (let i = 0; i < n; ++i) {
visibleLabels[dst] = selectedPointIndices[i]; visibleLabels[dst] = selectedPointIndices[i];
scale[dst] = LABEL_SCALE_LARGE; scale[dst] = LABEL_SCALE_LARGE;
opacityFlags[dst] = (n === 1) ? 0 : 1; opacityFlags[dst] = (n === 1) ? 0 : 1;
this.packRgbIntoUint8Array( packRgbIntoUint8Array(
fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
this.packRgbIntoUint8Array( packRgbIntoUint8Array(
strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]);
++dst; ++dst;
} }
@ -184,13 +290,13 @@ export class ProjectorScatterPlotAdapter {
// Neighbors // Neighbors
{ {
const n = neighborsOfFirstPoint.length; const n = neighborsOfFirstPoint.length;
const fillRgb = this.styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR); const fillRgb = styleRgbFromHexColor(LABEL_FILL_COLOR_NEIGHBOR);
const strokeRgb = this.styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR); const strokeRgb = styleRgbFromHexColor(LABEL_STROKE_COLOR_NEIGHBOR);
for (let i = 0; i < n; ++i) { for (let i = 0; i < n; ++i) {
visibleLabels[dst] = neighborsOfFirstPoint[i].index; visibleLabels[dst] = neighborsOfFirstPoint[i].index;
this.packRgbIntoUint8Array( packRgbIntoUint8Array(
fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]); fillColors, dst, fillRgb[0], fillRgb[1], fillRgb[2]);
this.packRgbIntoUint8Array( packRgbIntoUint8Array(
strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]); strokeColors, dst, strokeRgb[0], strokeRgb[1], strokeRgb[2]);
++dst; ++dst;
} }
@ -248,7 +354,6 @@ export class ProjectorScatterPlotAdapter {
for (let i = 0; i < ds.traces.length; i++) { for (let i = 0; i < ds.traces.length; i++) {
let dataTrace = ds.traces[i]; let dataTrace = ds.traces[i];
let colors = let colors =
new Float32Array(2 * (dataTrace.pointIndices.length - 1) * 3); new Float32Array(2 * (dataTrace.pointIndices.length - 1) * 3);
let colorIndex = 0; let colorIndex = 0;
@ -262,21 +367,19 @@ export class ProjectorScatterPlotAdapter {
colors[colorIndex++] = c1.r; colors[colorIndex++] = c1.r;
colors[colorIndex++] = c1.g; colors[colorIndex++] = c1.g;
colors[colorIndex++] = c1.b; colors[colorIndex++] = c1.b;
colors[colorIndex++] = c2.r; colors[colorIndex++] = c2.r;
colors[colorIndex++] = c2.g; colors[colorIndex++] = c2.g;
colors[colorIndex++] = c2.b; colors[colorIndex++] = c2.b;
} }
} else { } else {
for (let j = 0; j < dataTrace.pointIndices.length - 1; j++) { for (let j = 0; j < dataTrace.pointIndices.length - 1; j++) {
const c1 = this.getDefaultPointInTraceColor( const c1 =
j, dataTrace.pointIndices.length); getDefaultPointInTraceColor(j, dataTrace.pointIndices.length);
const c2 = this.getDefaultPointInTraceColor( const c2 =
j + 1, dataTrace.pointIndices.length); getDefaultPointInTraceColor(j + 1, dataTrace.pointIndices.length);
colors[colorIndex++] = c1.r; colors[colorIndex++] = c1.r;
colors[colorIndex++] = c1.g; colors[colorIndex++] = c1.g;
colors[colorIndex++] = c1.b; colors[colorIndex++] = c1.b;
colors[colorIndex++] = c2.r; colors[colorIndex++] = c2.r;
colors[colorIndex++] = c2.g; colors[colorIndex++] = c2.g;
colors[colorIndex++] = c2.b; colors[colorIndex++] = c2.b;
@ -319,15 +422,6 @@ export class ProjectorScatterPlotAdapter {
return widths; return widths;
} }
private getDefaultPointInTraceColor(index: number, totalPoints: number):
THREE.Color {
let hue = TRACE_START_HUE +
(TRACE_END_HUE - TRACE_START_HUE) * index / totalPoints;
let rgb = d3.hsl(hue, TRACE_SATURATION, TRACE_LIGHTNESS).rgb();
return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255);
}
generatePointColorArray( generatePointColorArray(
ds: DataSet, legendPointColorer: (index: number) => string, ds: DataSet, legendPointColorer: (index: number) => string,
distFunc: DistanceFunction, selectedPointIndices: number[], distFunc: DistanceFunction, selectedPointIndices: number[],
@ -419,6 +513,66 @@ export class ProjectorScatterPlotAdapter {
return colors; return colors;
} }
private updateScatterPlotWithNewProjection(projection: Projection) {
if (projection != null) {
this.scatterPlot.setDimensions(projection.dimensionality);
if (projection.dataSet.projectionCanBeRendered(
projection.projectionType)) {
this.updateScatterPlotAttributes();
this.notifyProjectionPositionsUpdated();
}
this.scatterPlot.setCameraParametersForNextCameraCreation(null, false);
} else {
this.updateScatterPlotAttributes();
this.notifyProjectionPositionsUpdated();
}
}
private createVisualizers(inLabels3DMode: boolean) {
const scatterPlot = this.scatterPlot;
scatterPlot.removeAllVisualizers();
if (inLabels3DMode) {
scatterPlot.addVisualizer(new ScatterPlotVisualizer3DLabels());
} else {
scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites());
scatterPlot.addVisualizer(
new ScatterPlotVisualizerCanvasLabels(this.scatterPlotContainer));
}
scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces());
}
private getSpriteImageMode(): boolean {
if (this.projection == null) {
return false;
}
const ds = this.projection.dataSet;
if ((ds == null) || (ds.spriteAndMetadataInfo == null)) {
return false;
}
return ds.spriteAndMetadataInfo.spriteImage != null;
}
}
function packRgbIntoUint8Array(
rgbArray: Uint8Array, labelIndex: number, r: number, g: number, b: number) {
rgbArray[labelIndex * 3] = r;
rgbArray[labelIndex * 3 + 1] = g;
rgbArray[labelIndex * 3 + 2] = b;
}
function styleRgbFromHexColor(hex: number): [number, number, number] {
const c = new THREE.Color(hex);
return [(c.r * 255) | 0, (c.g * 255) | 0, (c.b * 255) | 0];
}
function getDefaultPointInTraceColor(
index: number, totalPoints: number): THREE.Color {
let hue =
TRACE_START_HUE + (TRACE_END_HUE - TRACE_START_HUE) * index / totalPoints;
let rgb = d3.hsl(hue, TRACE_SATURATION, TRACE_LIGHTNESS).rgb();
return new THREE.Color(rgb.r / 255, rgb.g / 255, rgb.b / 255);
} }
/** /**

View File

@ -117,14 +117,12 @@ export class ScatterPlot {
private rectangleSelector: ScatterPlotRectangleSelector; private rectangleSelector: ScatterPlotRectangleSelector;
constructor( constructor(
container: d3.Selection<any>, labelAccessor: (index: number) => string, container: d3.Selection<any>,
projectorEventContext: ProjectorEventContext) { projectorEventContext: ProjectorEventContext) {
this.containerNode = container.node() as HTMLElement; this.containerNode = container.node() as HTMLElement;
this.projectorEventContext = projectorEventContext; this.projectorEventContext = projectorEventContext;
this.getLayoutValues(); this.getLayoutValues();
this.labelAccessor = labelAccessor;
this.scene = new THREE.Scene(); this.scene = new THREE.Scene();
this.renderer = this.renderer =
new THREE.WebGLRenderer({alpha: true, premultipliedAlpha: false}); new THREE.WebGLRenderer({alpha: true, premultipliedAlpha: false});
@ -457,7 +455,7 @@ export class ScatterPlot {
return this.dimensionality === 3; return this.dimensionality === 3;
} }
private remove3dAxis(): THREE.Object3D { private remove3dAxisFromScene(): THREE.Object3D {
const axes = this.scene.getObjectByName('axes'); const axes = this.scene.getObjectByName('axes');
if (axes != null) { if (axes != null) {
this.scene.remove(axes); this.scene.remove(axes);
@ -481,7 +479,7 @@ export class ScatterPlot {
const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality); const def = this.cameraDef || this.makeDefaultCameraDef(dimensionality);
this.recreateCamera(def); this.recreateCamera(def);
this.remove3dAxis(); this.remove3dAxisFromScene();
if (dimensionality === 3) { if (dimensionality === 3) {
this.add3dAxis(); this.add3dAxis();
} }
@ -624,10 +622,12 @@ export class ScatterPlot {
}); });
{ {
const axes = this.remove3dAxis(); const axes = this.remove3dAxisFromScene();
this.renderer.render(this.scene, this.camera, this.pickingTexture); this.renderer.render(this.scene, this.camera, this.pickingTexture);
if (axes != null) {
this.scene.add(axes); this.scene.add(axes);
} }
}
// Render second pass to color buffer, to be displayed on the canvas. // Render second pass to color buffer, to be displayed on the canvas.
this.visualizers.forEach(v => { this.visualizers.forEach(v => {

View File

@ -305,14 +305,15 @@ export class ScatterPlotVisualizerSprites implements ScatterPlotVisualizer {
onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) { onPointPositionsChanged(newPositions: Float32Array, dataSet: DataSet) {
if (this.points != null) { if (this.points != null) {
const notEnoughSpace = (this.pickingColors.length < newPositions.length); const notEnoughSpace = (this.pickingColors.length < newPositions.length);
const newImage = const newImage = (dataSet != null) &&
(this.image !== dataSet.spriteAndMetadataInfo.spriteImage); (this.image !== dataSet.spriteAndMetadataInfo.spriteImage);
if (notEnoughSpace || newImage) { if (notEnoughSpace || newImage) {
this.dispose(); this.dispose();
} }
} }
this.image = dataSet.spriteAndMetadataInfo.spriteImage; this.image =
(dataSet != null) ? dataSet.spriteAndMetadataInfo.spriteImage : null;
this.worldSpacePointPositions = newPositions; this.worldSpacePointPositions = newPositions;
if (this.points == null) { if (this.points == null) {

View File

@ -372,13 +372,14 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
const accessors = const accessors =
dataSet.getPointAccessors('tsne', [0, 1, this.tSNEis3d ? 2 : null]); dataSet.getPointAccessors('tsne', [0, 1, this.tSNEis3d ? 2 : null]);
const dimensionality = this.tSNEis3d ? 3 : 2; const dimensionality = this.tSNEis3d ? 3 : 2;
const projection = new Projection('tsne', accessors, dimensionality); const projection =
new Projection('tsne', accessors, dimensionality, dataSet);
this.projector.setProjection(projection); this.projector.setProjection(projection);
if (!this.dataSet.hasTSNERun) { if (!this.dataSet.hasTSNERun) {
this.runTSNE(); this.runTSNE();
} else { } else {
this.projector.notifyProjectionsUpdated(); this.projector.notifyProjectionPositionsUpdated();
} }
} }
@ -390,7 +391,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
(iteration: number) => { (iteration: number) => {
if (iteration != null) { if (iteration != null) {
this.iterationLabel.text(iteration); this.iterationLabel.text(iteration);
this.projector.notifyProjectionsUpdated(); this.projector.notifyProjectionPositionsUpdated();
} else { } else {
this.runTsneButton.attr('disabled', null); this.runTsneButton.attr('disabled', null);
this.stopTsneButton.attr('disabled', true); this.stopTsneButton.attr('disabled', true);
@ -426,7 +427,8 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
'pca', [this.pcaX, this.pcaY, this.pcaZ]); 'pca', [this.pcaX, this.pcaY, this.pcaZ]);
const dimensionality = this.pcaIs3d ? 3 : 2; const dimensionality = this.pcaIs3d ? 3 : 2;
const projection = new Projection('pca', accessors, dimensionality); const projection =
new Projection('pca', accessors, dimensionality, this.dataSet);
this.projector.setProjection(projection); this.projector.setProjection(projection);
let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]); let numComponents = Math.min(NUM_PCA_COMPONENTS, this.dataSet.dim[1]);
this.updateTotalVarianceMessage(); this.updateTotalVarianceMessage();
@ -454,7 +456,7 @@ export class ProjectionsPanel extends ProjectionsPanelPolymer {
this.dataSet.projectLinear(yDir, 'linear-y'); this.dataSet.projectLinear(yDir, 'linear-y');
const accessors = this.dataSet.getPointAccessors('custom', ['x', 'y']); const accessors = this.dataSet.getPointAccessors('custom', ['x', 'y']);
const projection = new Projection('custom', accessors, 2); const projection = new Projection('custom', accessors, 2, this.dataSet);
this.projector.setProjection(projection); this.projector.setProjection(projection);
} }

View File

@ -21,13 +21,9 @@ import {ProtoDataProvider} from './data-provider-proto';
import {ServerDataProvider} from './data-provider-server'; import {ServerDataProvider} from './data-provider-server';
import * as knn from './knn'; import * as knn from './knn';
import * as logging from './logging'; import * as logging from './logging';
import {HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext'; import {DistanceMetricChangedListener, HoverListener, ProjectionChangedListener, ProjectorEventContext, SelectionChangedListener} from './projectorEventContext';
import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter'; import {ProjectorScatterPlotAdapter} from './projectorScatterPlotAdapter';
import {Mode, ScatterPlot} from './scatterPlot'; import {Mode} from './scatterPlot';
import {ScatterPlotVisualizer3DLabels} from './scatterPlotVisualizer3DLabels';
import {ScatterPlotVisualizerCanvasLabels} from './scatterPlotVisualizerCanvasLabels';
import {ScatterPlotVisualizerSprites} from './scatterPlotVisualizerSprites';
import {ScatterPlotVisualizerTraces} from './scatterPlotVisualizerTraces';
import * as util from './util'; import * as util from './util';
import {BookmarkPanel} from './vz-projector-bookmark-panel'; import {BookmarkPanel} from './vz-projector-bookmark-panel';
import {DataPanel} from './vz-projector-data-panel'; import {DataPanel} from './vz-projector-data-panel';
@ -69,11 +65,11 @@ export class Projector extends ProjectorPolymer implements
private selectionChangedListeners: SelectionChangedListener[]; private selectionChangedListeners: SelectionChangedListener[];
private hoverListeners: HoverListener[]; private hoverListeners: HoverListener[];
private projectionChangedListeners: ProjectionChangedListener[]; private projectionChangedListeners: ProjectionChangedListener[];
private distanceMetricChangedListeners: DistanceMetricChangedListener[];
private originalDataSet: DataSet; private originalDataSet: DataSet;
private dom: d3.Selection<any>; private dom: d3.Selection<any>;
private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter; private projectorScatterPlotAdapter: ProjectorScatterPlotAdapter;
private scatterPlot: ScatterPlot;
private dim: number; private dim: number;
private dataSetFilterIndices: number[]; private dataSetFilterIndices: number[];
@ -108,6 +104,7 @@ export class Projector extends ProjectorPolymer implements
this.selectionChangedListeners = []; this.selectionChangedListeners = [];
this.hoverListeners = []; this.hoverListeners = [];
this.projectionChangedListeners = []; this.projectionChangedListeners = [];
this.distanceMetricChangedListeners = [];
this.selectedPointIndices = []; this.selectedPointIndices = [];
this.neighborsOfFirstPoint = []; this.neighborsOfFirstPoint = [];
this.dom = d3.select(this); this.dom = d3.select(this);
@ -133,14 +130,17 @@ export class Projector extends ProjectorPolymer implements
.metadata[this.selectedLabelOption] as string; .metadata[this.selectedLabelOption] as string;
}; };
this.metadataCard.setLabelOption(this.selectedLabelOption); this.metadataCard.setLabelOption(this.selectedLabelOption);
this.scatterPlot.setLabelAccessor(labelAccessor); this.projectorScatterPlotAdapter.scatterPlot.setLabelAccessor(
this.scatterPlot.render(); labelAccessor);
this.projectorScatterPlotAdapter.render();
} }
setSelectedColorOption(colorOption: ColorOption) { setSelectedColorOption(colorOption: ColorOption) {
this.selectedColorOption = colorOption; this.selectedColorOption = colorOption;
this.updateScatterPlotAttributes(); this.projectorScatterPlotAdapter.setLegendPointColorer(
this.scatterPlot.render(); this.getLegendPointColorer(colorOption));
this.projectorScatterPlotAdapter.updateScatterPlotAttributes();
this.projectorScatterPlotAdapter.render();
} }
setNormalizeData(normalizeData: boolean) { setNormalizeData(normalizeData: boolean) {
@ -153,8 +153,7 @@ export class Projector extends ProjectorPolymer implements
metadataFile?: string) { metadataFile?: string) {
this.dataSetFilterIndices = null; this.dataSetFilterIndices = null;
this.originalDataSet = ds; this.originalDataSet = ds;
if (this.scatterPlot == null || ds == null) { if (this.projectorScatterPlotAdapter == null || ds == null) {
// We are not ready yet.
return; return;
} }
this.normalizeData = this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE; this.normalizeData = this.originalDataSet.dim[1] >= THRESHOLD_DIM_NORMALIZE;
@ -176,8 +175,8 @@ export class Projector extends ProjectorPolymer implements
// height can grow indefinitely. // height can grow indefinitely.
let container = this.dom.select('#container'); let container = this.dom.select('#container');
container.style('height', container.property('clientHeight') + 'px'); container.style('height', container.property('clientHeight') + 'px');
this.scatterPlot.resize(); this.projectorScatterPlotAdapter.resize();
this.scatterPlot.render(); this.projectorScatterPlotAdapter.render();
} }
setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) { setSelectedTensor(run: string, tensorInfo: EmbeddingInfo) {
@ -203,7 +202,7 @@ export class Projector extends ProjectorPolymer implements
return this.dataSet.points[localIndex].index; return this.dataSet.points[localIndex].index;
}); });
this.setCurrentDataSet(this.originalDataSet.getSubset()); this.setCurrentDataSet(this.originalDataSet.getSubset());
this.updateScatterPlotPositions(); this.projectorScatterPlotAdapter.updateScatterPlotPositions();
this.dataSetFilterIndices = []; this.dataSetFilterIndices = [];
this.adjustSelectionAndHover(originalPointIndices); this.adjustSelectionAndHover(originalPointIndices);
} }
@ -247,8 +246,16 @@ export class Projector extends ProjectorPolymer implements
this.projectionChangedListeners.push(listener); this.projectionChangedListeners.push(listener);
} }
notifyProjectionChanged(dataSet: DataSet) { notifyProjectionChanged(projection: Projection) {
this.projectionChangedListeners.forEach(l => l(dataSet)); this.projectionChangedListeners.forEach(l => l(projection));
}
registerDistanceMetricChangedListener(l: DistanceMetricChangedListener) {
this.distanceMetricChangedListeners.push(l);
}
notifyDistanceMetricChanged(distMetric: DistanceFunction) {
this.distanceMetricChangedListeners.forEach(l => l(distMetric));
} }
_dataProtoChanged(dataProtoString: string) { _dataProtoChanged(dataProtoString: string) {
@ -324,11 +331,6 @@ export class Projector extends ProjectorPolymer implements
return (label3DModeButton as any).active; return (label3DModeButton as any).active;
} }
private getSpriteImageMode(): boolean {
return this.dataSet && this.dataSet.spriteAndMetadataInfo &&
this.dataSet.spriteAndMetadataInfo.spriteImage != null;
}
adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) { adjustSelectionAndHover(selectedPointIndices: number[], hoverIndex?: number) {
this.notifySelectionChanged(selectedPointIndices); this.notifySelectionChanged(selectedPointIndices);
this.notifyHoverOverPoint(hoverIndex); this.notifyHoverOverPoint(hoverIndex);
@ -338,8 +340,7 @@ export class Projector extends ProjectorPolymer implements
private setMode(mode: Mode) { private setMode(mode: Mode) {
let selectModeButton = this.querySelector('#selectMode'); let selectModeButton = this.querySelector('#selectMode');
(selectModeButton as any).active = (mode === Mode.SELECT); (selectModeButton as any).active = (mode === Mode.SELECT);
this.projectorScatterPlotAdapter.scatterPlot.setMode(mode);
this.scatterPlot.setMode(mode);
} }
private setCurrentDataSet(ds: DataSet) { private setCurrentDataSet(ds: DataSet) {
@ -360,14 +361,15 @@ export class Projector extends ProjectorPolymer implements
this.projectionsPanel.dataSetUpdated( this.projectionsPanel.dataSetUpdated(
this.dataSet, this.originalDataSet, this.dim); this.dataSet, this.originalDataSet, this.dim);
this.scatterPlot.setCameraParametersForNextCameraCreation(null, true); this.projectorScatterPlotAdapter.scatterPlot
.setCameraParametersForNextCameraCreation(null, true);
} }
private setupUIControls() { private setupUIControls() {
// View controls // View controls
this.querySelector('#reset-zoom').addEventListener('click', () => { this.querySelector('#reset-zoom').addEventListener('click', () => {
this.scatterPlot.resetZoom(); this.projectorScatterPlotAdapter.scatterPlot.resetZoom();
this.scatterPlot.startOrbitAnimation(); this.projectorScatterPlotAdapter.scatterPlot.startOrbitAnimation();
}); });
let selectModeButton = this.querySelector('#selectMode'); let selectModeButton = this.querySelector('#selectMode');
@ -376,14 +378,13 @@ export class Projector extends ProjectorPolymer implements
}); });
let nightModeButton = this.querySelector('#nightDayMode'); let nightModeButton = this.querySelector('#nightDayMode');
nightModeButton.addEventListener('click', () => { nightModeButton.addEventListener('click', () => {
this.scatterPlot.setDayNightMode((nightModeButton as any).active); this.projectorScatterPlotAdapter.scatterPlot.setDayNightMode(
(nightModeButton as any).active);
}); });
const labels3DModeButton = this.get3DLabelModeButton(); const labels3DModeButton = this.get3DLabelModeButton();
labels3DModeButton.addEventListener('click', () => { labels3DModeButton.addEventListener('click', () => {
this.createVisualizers(this.get3DLabelMode()); this.projectorScatterPlotAdapter.set3DLabelMode(this.get3DLabelMode());
this.updateScatterPlotAttributes();
this.scatterPlot.render();
}); });
window.addEventListener('resize', () => { window.addEventListener('resize', () => {
@ -391,18 +392,19 @@ export class Projector extends ProjectorPolymer implements
let parentHeight = let parentHeight =
(container.node().parentNode as HTMLElement).clientHeight; (container.node().parentNode as HTMLElement).clientHeight;
container.style('height', parentHeight + 'px'); container.style('height', parentHeight + 'px');
this.scatterPlot.resize(); this.projectorScatterPlotAdapter.resize();
}); });
this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter(); {
const labelAccessor = i =>
'' + this.dataSet.points[i].metadata[this.selectedLabelOption];
this.projectorScatterPlotAdapter = new ProjectorScatterPlotAdapter(
this.getScatterContainer(), this as ProjectorEventContext);
this.projectorScatterPlotAdapter.scatterPlot.setLabelAccessor(
labelAccessor);
}
this.scatterPlot = new ScatterPlot( this.projectorScatterPlotAdapter.scatterPlot.onCameraMove(
this.getScatterContainer(),
i => '' + this.dataSet.points[i].metadata[this.selectedLabelOption],
this as ProjectorEventContext);
this.createVisualizers(false);
this.scatterPlot.onCameraMove(
(cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) => (cameraPosition: THREE.Vector3, cameraTarget: THREE.Vector3) =>
this.bookmarkPanel.clearStateSelection()); this.bookmarkPanel.clearStateSelection());
@ -425,75 +427,16 @@ export class Projector extends ProjectorPolymer implements
hoverText = point.metadata[this.selectedLabelOption].toString(); hoverText = point.metadata[this.selectedLabelOption].toString();
} }
} }
this.updateScatterPlotAttributes();
this.scatterPlot.render();
if (this.selectedPointIndices.length === 0) { if (this.selectedPointIndices.length === 0) {
this.statusBar.style('display', hoverText ? null : 'none'); this.statusBar.style('display', hoverText ? null : 'none');
this.statusBar.text(hoverText); this.statusBar.text(hoverText);
} }
} }
private updateScatterPlotPositions() {
if (this.dataSet == null) {
return;
}
if (this.projection == null) {
return;
}
const newPositions =
this.projectorScatterPlotAdapter.generatePointPositionArray(
this.dataSet, this.projection.pointAccessors);
this.scatterPlot.setPointPositions(this.dataSet, newPositions);
}
private updateScatterPlotAttributes() {
const dataSet = this.dataSet;
const selectedSet = this.selectedPointIndices;
const hoverIndex = this.hoverPointIndex;
const neighbors = this.neighborsOfFirstPoint;
const pointColorer = this.getLegendPointColorer(this.selectedColorOption);
const adapter = this.projectorScatterPlotAdapter;
const pointColors = adapter.generatePointColorArray(
dataSet, pointColorer, this.inspectorPanel.distFunc, selectedSet,
neighbors, hoverIndex, this.get3DLabelMode(),
this.getSpriteImageMode());
const pointScaleFactors = adapter.generatePointScaleFactorArray(
dataSet, selectedSet, neighbors, hoverIndex);
const labels = adapter.generateVisibleLabelRenderParams(
dataSet, selectedSet, neighbors, hoverIndex);
const traceColors =
adapter.generateLineSegmentColorMap(dataSet, pointColorer);
const traceOpacities =
adapter.generateLineSegmentOpacityArray(dataSet, selectedSet);
const traceWidths =
adapter.generateLineSegmentWidthArray(dataSet, selectedSet);
this.scatterPlot.setPointColors(pointColors);
this.scatterPlot.setPointScaleFactors(pointScaleFactors);
this.scatterPlot.setLabels(labels);
this.scatterPlot.setTraceColors(traceColors);
this.scatterPlot.setTraceOpacities(traceOpacities);
this.scatterPlot.setTraceWidths(traceWidths);
}
private getScatterContainer(): d3.Selection<any> { private getScatterContainer(): d3.Selection<any> {
return this.dom.select('#scatter'); return this.dom.select('#scatter');
} }
private createVisualizers(inLabels3DMode: boolean) {
const scatterPlot = this.scatterPlot;
scatterPlot.removeAllVisualizers();
if (inLabels3DMode) {
scatterPlot.addVisualizer(new ScatterPlotVisualizer3DLabels());
} else {
scatterPlot.addVisualizer(new ScatterPlotVisualizerSprites());
scatterPlot.addVisualizer(
new ScatterPlotVisualizerCanvasLabels(this.getScatterContainer()));
}
scatterPlot.addVisualizer(new ScatterPlotVisualizerTraces());
}
private onSelectionChanged( private onSelectionChanged(
selectedPointIndices: number[], selectedPointIndices: number[],
neighborsOfFirstPoint: knn.NearestEntry[]) { neighborsOfFirstPoint: knn.NearestEntry[]) {
@ -503,26 +446,18 @@ export class Projector extends ProjectorPolymer implements
this.selectedPointIndices.length + neighborsOfFirstPoint.length; this.selectedPointIndices.length + neighborsOfFirstPoint.length;
this.statusBar.text(`Selected ${totalNumPoints} points`) this.statusBar.text(`Selected ${totalNumPoints} points`)
.style('display', totalNumPoints > 0 ? null : 'none'); .style('display', totalNumPoints > 0 ? null : 'none');
this.updateScatterPlotAttributes();
this.scatterPlot.render();
} }
setProjection(projection: Projection) { setProjection(projection: Projection) {
this.projection = projection; this.projection = projection;
this.scatterPlot.setDimensions(projection.dimensionality); if (projection != null) {
this.analyticsLogger.logProjectionChanged(projection.projectionType); this.analyticsLogger.logProjectionChanged(projection.projectionType);
if (this.dataSet.projectionCanBeRendered(projection.projectionType)) { }
this.updateScatterPlotAttributes(); this.notifyProjectionChanged(projection);
this.notifyProjectionsUpdated();
} }
this.scatterPlot.setCameraParametersForNextCameraCreation(null, false); notifyProjectionPositionsUpdated() {
this.notifyProjectionChanged(this.dataSet); this.projectorScatterPlotAdapter.notifyProjectionPositionsUpdated();
}
notifyProjectionsUpdated() {
this.updateScatterPlotPositions();
this.scatterPlot.render();
} }
/** /**
@ -547,7 +482,7 @@ export class Projector extends ProjectorPolymer implements
state.tSNEIteration = this.dataSet.tSNEIteration; state.tSNEIteration = this.dataSet.tSNEIteration;
state.selectedPoints = this.selectedPointIndices; state.selectedPoints = this.selectedPointIndices;
state.filteredPoints = this.dataSetFilterIndices; state.filteredPoints = this.dataSetFilterIndices;
state.cameraDef = this.scatterPlot.getCameraDef(); this.projectorScatterPlotAdapter.populateBookmarkFromUI(state);
state.selectedColorOptionName = this.dataPanel.selectedColorOptionName; state.selectedColorOptionName = this.dataPanel.selectedColorOptionName;
state.selectedLabelOption = this.selectedLabelOption; state.selectedLabelOption = this.selectedLabelOption;
this.projectionsPanel.populateBookmarkFromUI(state); this.projectionsPanel.populateBookmarkFromUI(state);
@ -556,6 +491,7 @@ export class Projector extends ProjectorPolymer implements
/** Loads a State object into the world. */ /** Loads a State object into the world. */
loadState(state: State) { loadState(state: State) {
this.setProjection(null);
{ {
this.projectionsPanel.disablePolymerChangesTriggerReprojection(); this.projectionsPanel.disablePolymerChangesTriggerReprojection();
this.resetFilterDataset(); this.resetFilterDataset();
@ -578,23 +514,17 @@ export class Projector extends ProjectorPolymer implements
this.inspectorPanel.restoreUIFromBookmark(state); this.inspectorPanel.restoreUIFromBookmark(state);
this.dataPanel.selectedColorOptionName = state.selectedColorOptionName; this.dataPanel.selectedColorOptionName = state.selectedColorOptionName;
this.selectedLabelOption = state.selectedLabelOption; this.selectedLabelOption = state.selectedLabelOption;
this.scatterPlot.setCameraParametersForNextCameraCreation( this.projectorScatterPlotAdapter.restoreUIFromBookmark(state);
state.cameraDef, false);
{ {
const dimensions = stateGetAccessorDimensions(state); const dimensions = stateGetAccessorDimensions(state);
const accessors = const accessors =
this.dataSet.getPointAccessors(state.selectedProjection, dimensions); this.dataSet.getPointAccessors(state.selectedProjection, dimensions);
const projection = new Projection( const projection = new Projection(
state.selectedProjection, accessors, dimensions.length); state.selectedProjection, accessors, dimensions.length, this.dataSet);
this.setProjection(projection); this.setProjection(projection);
} }
this.notifySelectionChanged(state.selectedPoints); this.notifySelectionChanged(state.selectedPoints);
} }
notifyDistanceMetricChanged(distMetric: DistanceFunction) {
this.updateScatterPlotAttributes();
this.scatterPlot.render();
}
} }
document.registerElement(Projector.prototype.is, Projector); document.registerElement(Projector.prototype.is, Projector);