fix: new gpu api for graphinForce

This commit is contained in:
Yanyan-Wang 2020-11-06 20:58:19 +08:00 committed by Yanyan Wang
parent b487f02332
commit 906d1653c0
8 changed files with 101 additions and 91 deletions

View File

@ -49,8 +49,8 @@ const louvain = (
// the origin data
const { nodes, edges } = data;
let clusters = {};
let nodeMap = {};
const clusters = {};
const nodeMap = {};
// init the clusters and nodeMap
nodes.forEach((node, i) => {
const cid: string = uniqueId();
@ -60,7 +60,7 @@ const louvain = (
nodes: [node]
};
nodeMap[node.id] = {
node: node,
node,
idx: i
};
});

View File

@ -272,7 +272,7 @@ export default class LayoutController {
graph.emit('beforelayout');
const offScreenCanvas = document.createElement('canvas');
const gpuWorkerAbility = isGPU && !navigator['gpu'] && // WebGPU 还不支持 OffscreenCanvas
const gpuWorkerAbility = isGPU && !navigator[`gpu`] && // WebGPU 还不支持 OffscreenCanvas
'OffscreenCanvas' in window && 'transferControlToOffscreen' in offScreenCanvas;

View File

@ -346,7 +346,7 @@ export default class TreeGraph extends Graph implements ITreeGraph {
// 如果没有父节点或找不到该节点是全量的更新直接重置data
if (!parentId || !self.findById(parentId)) {
console.warn(`Update children failed! There is no node with id \'${parentId}\'`);
console.warn(`Update children failed! There is no node with id '${parentId}'`);
return;
}

View File

@ -31,31 +31,42 @@ type NodeMap = {
export default class FruchtermanGPULayout extends BaseLayout {
/** 布局中心 */
public center: IPointTuple = [0, 0];
/** 停止迭代的最大迭代数 */
public maxIteration: number = 1000;
/** 重力大小,影响图的紧凑程度 */
public gravity: number = 10;
/** 速度 */
public speed: number = 1;
/** 是否产生聚类力 */
public clustering: boolean = false;
/** 根据哪个字段聚类 */
public clusterField: string = 'cluster';
/** 聚类力大小 */
public clusterGravity: number = 10;
/** 是否启用web worker。前提是在web worker里执行布局否则无效 */
public workerEnabled: boolean = false;
public nodes: Node[] = [];
public edges: Edge[] = [];
public width: number = 300;
public height: number = 300;
public nodeMap: NodeMap = {};
public nodeIdxMap: NodeIdxMap = {};
public canvasEl: HTMLCanvasElement;
public onLayoutEnd: () => void;
public getDefaultCfg() {
@ -68,6 +79,7 @@ export default class FruchtermanGPULayout extends BaseLayout {
clusterGravity: 10,
};
}
/**
*
*/
@ -97,6 +109,7 @@ export default class FruchtermanGPULayout extends BaseLayout {
// layout
self.run();
}
public executeWithWorker(canvas?: HTMLCanvasElement, ctx?: any) {
const self = this;
const nodes = self.nodes;
@ -148,7 +161,7 @@ export default class FruchtermanGPULayout extends BaseLayout {
const numParticles = nodes.length;
const { maxEdgePerVetex, array: nodesEdgesArray } = buildTextureData(nodes, edges);
let workerEnabled = self.workerEnabled;
const workerEnabled = self.workerEnabled;
let world;

View File

@ -9,9 +9,10 @@ import { isNumber } from '@antv/util';
import { World } from '@antv/g-webgpu';
import { proccessToFunc, buildTextureDataWithTwoEdgeAttr, arrayToTextureData } from '../../util/layout'
import { getDegree } from '../../util/math'
import { gCode, cCode } from './graphinForceShader';
import { graphinForceCode } from './graphinForceShader';
// import { graphinForceCode, graphinForceBundle } from './graphinForceShader';
import { LAYOUT_MESSAGE } from '../worker/layoutConst';
import { Compiler } from '@antv/g-webgpu-compiler';
type NodeMap = {
[key: string]: NodeConfig;
@ -23,42 +24,59 @@ type NodeMap = {
export default class GraphinForceGPULayout extends BaseLayout {
/** 布局中心 */
public center: IPointTuple = [0, 0];
/** 停止迭代的最大迭代数 */
public maxIteration: number = 1000;
/** 弹簧引力系数 */
public edgeStrength: number | ((d?: any) => number) | undefined = 200;
/** 斥力系数 */
public nodeStrength: number | ((d?: any) => number) | undefined = 1000;
/** 库伦系数 */
public coulombDisScale: number = 0.005;
/** 阻尼系数 */
public damping: number = 0.9;
/** 最大速度 */
public maxSpeed: number = 1000;
/** 一次迭代的平均移动距离小于该值时停止迭代 */
public minMovement: number = 0.5;
/** 迭代中衰减 */
public interval: number = 0.02;
/** 斥力的一个系数 */
public factor: number = 1;
/** 每个节点质量的回调函数,若不指定,则默认使用度数作为节点质量 */
public getMass: ((d?: any) => number) | undefined;
/** 理想边长 */
public linkDistance: number | ((d?: any) => number) | undefined = 1;
/** 重力大小 */
public gravity: number = 10;
/** 每个节点中心力的 x、y、强度的回调函数若不指定则没有额外中心力 */
public getCenter: ((d?: any, degree?: number) => number[]) | undefined;
/** 是否启用web worker。前提是在web worker里执行布局否则无效 */
public workerEnabled: boolean = false;
public nodes: NodeConfig[] = [];
public edges: EdgeConfig[] = [];
public width: number = 300;
public height: number = 300;
public nodeMap: NodeMap = {};
public nodeIdxMap: NodeIdxMap = {};
public onLayoutEnd: () => void;
@ -74,6 +92,7 @@ export default class GraphinForceGPULayout extends BaseLayout {
clusterGravity: 10,
};
}
/**
*
*/
@ -195,101 +214,80 @@ export default class GraphinForceGPULayout extends BaseLayout {
masses, self.degrees, nodeStrengths,
centerXs, centerYs, centerGravities
]);
console.log('nodeAttributeArray', nodeAttributeArray, numParticles, maxEdgePerVetex)
let workerEnabled = self.workerEnabled;
let world;
if (workerEnabled) {
world = new World({
world = World.create({
canvas,
engineOptions: {
supportCompute: true,
},
});
} else {
world = new World({
world = World.create({
engineOptions: {
supportCompute: true,
}
});
}
const compute1 = world.createComputePipeline({
shader: gCode,
onCompleted: (result) => {
// 获取 Shader 的编译结果,用户可以输出到 console 中并保存
console.log(world.getPrecompiledBundle(compute));
},
});
// TODO: 最终的预编译代码放入到 graphinForceShader.ts 中直接引入,不再需要下面三行
const compiler = new Compiler();
const graphinForceBundle = compiler.compileBundle(graphinForceCode);
console.log(graphinForceBundle);
const onLayoutEnd = self.onLayoutEnd;
const compute = world.createComputePipeline({
shader: gCode,
// precompiled: true,
// shader: cCode,
dispatch: [numParticles, 1, 1],
maxIteration,//maxIteration,
onIterationCompleted: async (iter) => {
const stepInterval = Math.max(0.02, self.interval - iter * 0.002);
world.setBinding(
compute,
'u_interval',
stepInterval,
);
},
onCompleted: (finalParticleData) => {
if (canvas) {
// 传递数据给主线程
ctx.postMessage({
type: LAYOUT_MESSAGE.GPUEND,
vertexEdgeData: finalParticleData,
// edgeIndexBufferData,
});
} else {
nodes.forEach((node, i) => {
const x = finalParticleData[4 * i];
const y = finalParticleData[4 * i + 1];
node.x = x;
node.y = y;
});
}
const kernelGraphinForce = world
.createKernel(graphinForceBundle)
.setDispatch([numParticles, 1, 1])
.setBinding({
u_Data: nodesEdgesArray, // 节点边输入输出
u_damping: self.damping,
u_maxSpeed: self.maxSpeed,
u_minMovement: self.minMovement,
u_coulombDisScale: self.coulombDisScale,
u_factor: self.factor,
u_NodeAttributeArray: nodeAttributeArray,
MAX_EDGE_PER_VERTEX: maxEdgePerVetex,
VERTEX_COUNT: numParticles,
u_interval: self.interval // 每次迭代更新,首次设置为 interval在 onIterationCompleted 中更新
});
onLayoutEnd && onLayoutEnd();
// 执行迭代
const execute = async () => {
for (let i = 0; i < maxIteration; i++) {
await kernelGraphinForce.execute();
// 每次迭代完成后
const stepInterval = Math.max(0.02, self.interval - i * 0.002);
kernelGraphinForce.setBinding({
u_interval: stepInterval
});
}
const finalParticleData = await kernelGraphinForce.getOutput();
// 计算完成后销毁相关 GPU 资源
world.destroy();
},
});
// 所有迭代完成后
if (canvas) {
// 传递数据给主线程
ctx.postMessage({
type: LAYOUT_MESSAGE.GPUEND,
vertexEdgeData: finalParticleData,
// edgeIndexBufferData,
});
} else {
nodes.forEach((node, i) => {
const x = finalParticleData[4 * i];
const y = finalParticleData[4 * i + 1];
node.x = x;
node.y = y;
});
}
// 节点边输入输出
world.setBinding(compute, 'u_Data', nodesEdgesArray);
// // 布局中心
// world.setBinding(compute, 'u_CenterX', self.center[0]);
// world.setBinding(compute, 'u_CenterY', self.center[1]);
onLayoutEnd && onLayoutEnd();
}
// // 中心力
// world.setBinding(compute, 'u_gravity', self.gravity);
// // 聚集离散点
// world.setBinding(compute, 'u_gatherDiscrete', self.gatherDiscreteCenter ? 1 : 0);
// world.setBinding(compute, 'u_GatherDiscreteCenterX', self.gatherDiscreteCenter[0]);
// world.setBinding(compute, 'u_GatherDiscreteCenterY', self.gatherDiscreteCenter[1]);
// world.setBinding(compute, 'u_GatherDiscreteGravity', self.gatherDiscreteGravity);
// 常量
// world.setBinding(compute, 'u_stiffness', self.stiffness);
world.setBinding(compute, 'u_damping', self.damping);
world.setBinding(compute, 'u_maxSpeed', self.maxSpeed);
world.setBinding(compute, 'u_minMovement', self.minMovement);
world.setBinding(compute, 'u_coulombDisScale', self.coulombDisScale);
// world.setBinding(compute, 'u_repulsion', self.repulsion);
world.setBinding(compute, 'u_factor', self.factor);
world.setBinding(compute, 'u_NodeAttributeArray', nodeAttributeArray);
world.setBinding(compute, 'MAX_EDGE_PER_VERTEX', maxEdgePerVetex);
world.setBinding(compute, 'VERTEX_COUNT', numParticles);
// 每次迭代更新,首次设置为 interval在 onIterationCompleted 中更新
world.setBinding(compute, 'u_interval', self.interval);
execute();
}
}

File diff suppressed because one or more lines are too long

View File

@ -111,11 +111,11 @@ export default class GraphinForceLayout extends BaseLayout {
let nodeSizeFunc;
if (self.preventOverlap) {
const nodeSpacing = self.nodeSpacing;
let nodeSpacingFunc: Function;
let nodeSpacingFunc: ((d?: any) => number);
if (isNumber(nodeSpacing)) {
nodeSpacingFunc = () => nodeSpacing;
nodeSpacingFunc = () => nodeSpacing as number;
} else if (isFunction(nodeSpacing)) {
nodeSpacingFunc = nodeSpacing;
nodeSpacingFunc = nodeSpacing as ((d?: any) => number);
} else {
nodeSpacingFunc = () => 0;
}
@ -126,7 +126,7 @@ export default class GraphinForceLayout extends BaseLayout {
const res = d.size[0] > d.size[1] ? d.size[0] : d.size[1];
return res + nodeSpacingFunc(d);
}
return d.size + nodeSpacingFunc(d);
return (d.size as number) + nodeSpacingFunc(d);
}
return 10 + nodeSpacingFunc(d);
};
@ -136,7 +136,7 @@ export default class GraphinForceLayout extends BaseLayout {
return res + nodeSpacingFunc(d);
};
} else {
nodeSizeFunc = (d: NodeConfig) => nodeSize + nodeSpacingFunc(d);
nodeSizeFunc = (d: NodeConfig) => (nodeSize as number) + nodeSpacingFunc(d);
}
}
self.nodeSize = nodeSizeFunc;

View File

@ -239,11 +239,10 @@ export const gpuDetector = (): any => {
return element;
},
addGetWebGLMessage: function (parameters) {
let parent, id, element;
parameters = parameters || {};
parent = parameters.parent !== undefined ? parameters.parent : document.body;
id = parameters.id !== undefined ? parameters.id : 'oldie';
element = gpuDetector().getWebGLErrorMessage();
const parent = parameters.parent !== undefined ? parameters.parent : document.body;
const id = parameters.id !== undefined ? parameters.id : 'oldie';
const element = gpuDetector().getWebGLErrorMessage();
element.id = id;
parent.appendChild(element);
}