mirror of
https://gitee.com/antv/g6.git
synced 2024-11-30 18:58:34 +08:00
feat: add minimum spanning tree algorithm
This commit is contained in:
parent
30dae14260
commit
83501b91c7
@ -50,7 +50,7 @@
|
||||
"site:deploy": "npm run site:build && gh-pages -d public",
|
||||
"start": "npm run site:develop",
|
||||
"test": "jest",
|
||||
"test-live": "DEBUG_MODE=1 jest --watch ./tests/unit/algorithm/find-path-spec",
|
||||
"test-live": "DEBUG_MODE=1 jest --watch ./tests/unit/algorithm/mst-spec.ts",
|
||||
"lint-staged:js": "eslint --ext .js,.jsx,.ts,.tsx",
|
||||
"watch": "father build -w",
|
||||
"cdn": "antv-bin upload -n @antv/g6"
|
||||
|
@ -7,3 +7,4 @@ export { default as floydWarshall } from './floydWarshall'
|
||||
export { default as getConnectedComponents } from './connected-component'
|
||||
export { detectAllCycles, detectAllDirectedCycle, detectAllUndirectedCycle } from './detect-cycle'
|
||||
export { findShortestPath, findAllPath } from './find-path'
|
||||
export { default as minimumSpanningTree } from './mst'
|
||||
|
114
src/algorithm/mst.ts
Normal file
114
src/algorithm/mst.ts
Normal file
@ -0,0 +1,114 @@
|
||||
import { IGraph } from '../interface/graph';
|
||||
import { IEdge } from '../interface/item';
|
||||
import UnionFind from './structs/union-find';
|
||||
import MinBinaryHeap from './structs/binary-heap'
|
||||
|
||||
/**
|
||||
* Prim algorithm,use priority queue,复杂度 O(E+V*logV), V: 节点数量,E: 边的数量
|
||||
* refer: https://en.wikipedia.org/wiki/Prim%27s_algorithm
|
||||
* @param graph
|
||||
* @param weight 指定用于作为边权重的属性,若不指定,则认为所有边权重一致
|
||||
*/
|
||||
const primMST = (graph: IGraph, weight?: string) => {
|
||||
const selectedEdges = []
|
||||
const nodes = graph.getNodes()
|
||||
if (nodes.length === 0) {
|
||||
return selectedEdges;
|
||||
}
|
||||
|
||||
// 从nodes[0]开始
|
||||
let currNode = nodes[0];
|
||||
let visited = new Set()
|
||||
visited.add(currNode)
|
||||
|
||||
// 用二叉堆维护距已加入节点的其他节点的边的权值
|
||||
const compareWeight = (a: IEdge, b: IEdge) => {
|
||||
if (weight) {
|
||||
return (a.getModel()[weight] as number) - (b.getModel()[weight] as number)
|
||||
} else {
|
||||
return 0
|
||||
}
|
||||
}
|
||||
let edgeQueue = new MinBinaryHeap(compareWeight)
|
||||
currNode.getEdges().forEach(edge => {
|
||||
edgeQueue.insert(edge);
|
||||
});
|
||||
|
||||
while (!edgeQueue.isEmpty()) {
|
||||
// 选取与已加入的结点之间边权最小的结点
|
||||
// console.log(edgeQueue.list.map(edge => edge.getModel().weight))
|
||||
let currEdge = edgeQueue.delMin();
|
||||
const source = currEdge.getSource()
|
||||
const target = currEdge.getTarget()
|
||||
if (visited.has(source) && visited.has(target)) continue
|
||||
selectedEdges.push(currEdge);
|
||||
|
||||
if (!visited.has(source)) {
|
||||
visited.add(source)
|
||||
source.getEdges().forEach(edge => {
|
||||
edgeQueue.insert(edge);
|
||||
});
|
||||
}
|
||||
if (!visited.has(target)) {
|
||||
visited.add(target)
|
||||
target.getEdges().forEach(edge => {
|
||||
edgeQueue.insert(edge);
|
||||
});
|
||||
}
|
||||
}
|
||||
return selectedEdges;
|
||||
}
|
||||
|
||||
/**
|
||||
* Kruskal algorithm,复杂度 O(E*logE), E: 边的数量
|
||||
* refer: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
|
||||
* @param graph
|
||||
* @param weight 指定用于作为边权重的属性,若不指定,则认为所有边权重一致
|
||||
* @return IEdge[] 返回构成MST的边的数组
|
||||
*/
|
||||
const kruskalMST = (graph: IGraph, weight?: string): IEdge[] => {
|
||||
const selectedEdges = []
|
||||
if (graph.getNodes().length === 0) {
|
||||
return selectedEdges;
|
||||
}
|
||||
|
||||
// 若指定weight,则将所有的边按权值从小到大排序
|
||||
const edges = graph.getEdges().map(edge => edge)
|
||||
if (weight) {
|
||||
edges.sort((a, b) => {
|
||||
return a.getModel()[weight] as number - (b.getModel()[weight] as number);
|
||||
})
|
||||
}
|
||||
let disjointSet = new UnionFind(graph.getNodes().map(n => n.get('id')));
|
||||
|
||||
// 从权值最小的边开始,如果这条边连接的两个节点于图G中不在同一个连通分量中,则添加这条边
|
||||
// 直到遍历完所有点或边
|
||||
while (edges.length > 0) {
|
||||
let curEdge = edges.shift();
|
||||
let source = curEdge.getSource().get('id');
|
||||
let target = curEdge.getTarget().get('id');
|
||||
if (!disjointSet.connected(source, target)) {
|
||||
selectedEdges.push(curEdge)
|
||||
disjointSet.union(source, target);
|
||||
}
|
||||
}
|
||||
return selectedEdges;
|
||||
}
|
||||
|
||||
/**
|
||||
* 最小生成树
|
||||
* refer: https://en.wikipedia.org/wiki/Kruskal%27s_algorithm
|
||||
* @param graph
|
||||
* @param weight 指定用于作为边权重的属性,若不指定,则认为所有边权重一致
|
||||
* @param algo 'prim' | 'kruskal' 算法类型
|
||||
* @return IEdge[] 返回构成MST的边的数组
|
||||
*/
|
||||
export default function mst(graph: IGraph, weight?: string, algo?: string) {
|
||||
const algos = {
|
||||
'prim': primMST,
|
||||
'kruskal': kruskalMST,
|
||||
}
|
||||
if (!algo) return kruskalMST(graph, weight)
|
||||
|
||||
return algos[algo](graph, weight)
|
||||
}
|
73
src/algorithm/structs/binary-heap/index.ts
Normal file
73
src/algorithm/structs/binary-heap/index.ts
Normal file
@ -0,0 +1,73 @@
|
||||
const defaultCompare = (a, b) => { return a - b }
|
||||
export default class MinBinaryHeap {
|
||||
list: any[];
|
||||
compareFn: (a: any, b: any) => number;
|
||||
constructor(compareFn = defaultCompare) {
|
||||
this.compareFn = compareFn;
|
||||
this.list = [];
|
||||
}
|
||||
getLeft(index) {
|
||||
return 2 * index + 1;
|
||||
}
|
||||
getRight(index) {
|
||||
return 2 * index + 2;
|
||||
}
|
||||
getParent(index) {
|
||||
if (index === 0) {
|
||||
return null;
|
||||
}
|
||||
return Math.floor((index - 1) / 2);
|
||||
}
|
||||
isEmpty() {
|
||||
return this.list.length <= 0;
|
||||
}
|
||||
top() {
|
||||
return this.isEmpty() ? undefined : this.list[0];
|
||||
}
|
||||
delMin() {
|
||||
let top = this.top()
|
||||
const bottom = this.list.pop();
|
||||
if (this.list.length > 0) {
|
||||
this.list[0] = bottom
|
||||
this.moveDown(0)
|
||||
}
|
||||
return top
|
||||
}
|
||||
insert(value) {
|
||||
if (value !== null) {
|
||||
this.list.push(value);
|
||||
const index = this.list.length - 1;
|
||||
this.moveUp(index);
|
||||
return true;
|
||||
}
|
||||
return false;
|
||||
}
|
||||
moveUp(index) {
|
||||
let parent = this.getParent(index);
|
||||
while (index && index > 0 && this.compareFn(this.list[parent], this.list[index]) > 0) {
|
||||
// swap
|
||||
const tmp = this.list[parent]
|
||||
this.list[parent] = this.list[index]
|
||||
this.list[index] = tmp
|
||||
// [this.list[index], this.list[parent]] = [this.list[parent], this.list[index]]
|
||||
index = parent;
|
||||
parent = this.getParent(index);
|
||||
}
|
||||
}
|
||||
moveDown(index) {
|
||||
let element = index;
|
||||
const left = this.getLeft(index);
|
||||
const right = this.getRight(index);
|
||||
const size = this.list.length;
|
||||
if (left !== null && left < size && this.compareFn(this.list[element], this.list[left]) > 0) {
|
||||
element = left;
|
||||
} else if (right !== null && right < size && this.compareFn(this.list[element], this.list[right]) > 0) {
|
||||
element = right;
|
||||
}
|
||||
if (index !== element) {
|
||||
[this.list[index], this.list[element]] = [this.list[element], this.list[index]]
|
||||
this.moveDown(element);
|
||||
}
|
||||
}
|
||||
}
|
||||
|
43
src/algorithm/structs/union-find/index.ts
Normal file
43
src/algorithm/structs/union-find/index.ts
Normal file
@ -0,0 +1,43 @@
|
||||
/**
|
||||
* 并查集 Disjoint set to support quick union
|
||||
*/
|
||||
export default class UnionFind {
|
||||
count: number;
|
||||
parent: {};
|
||||
constructor(items: (number | string)[]) {
|
||||
this.count = items.length;
|
||||
this.parent = {};
|
||||
for (let i of items) {
|
||||
this.parent[i] = i;
|
||||
}
|
||||
}
|
||||
|
||||
// find the root of the item
|
||||
find(item) {
|
||||
while (this.parent[item] !== item) {
|
||||
item = this.parent[item];
|
||||
}
|
||||
return item;
|
||||
}
|
||||
|
||||
union(a, b) {
|
||||
let rootA = this.find(a);
|
||||
let rootB = this.find(b);
|
||||
|
||||
if (rootA === rootB) return;
|
||||
|
||||
// make the element with smaller root the parent
|
||||
if (rootA < rootB) {
|
||||
if (this.parent[b] != b) this.union(this.parent[b], a);
|
||||
this.parent[b] = this.parent[a];
|
||||
} else {
|
||||
if (this.parent[a] != a) this.union(this.parent[a], b);
|
||||
this.parent[a] = this.parent[b];
|
||||
}
|
||||
}
|
||||
|
||||
// whether a and b are connected, i.e. a and b have the same root
|
||||
connected(a, b) {
|
||||
return this.find(a) === this.find(b);
|
||||
}
|
||||
}
|
130
tests/unit/algorithm/mst-spec.ts
Normal file
130
tests/unit/algorithm/mst-spec.ts
Normal file
@ -0,0 +1,130 @@
|
||||
import G6, { Algorithm } from '../../../src';
|
||||
const { minimumSpanningTree } = Algorithm
|
||||
|
||||
const div = document.createElement('div');
|
||||
div.id = 'container';
|
||||
document.body.appendChild(div);
|
||||
|
||||
const data = {
|
||||
nodes: [
|
||||
{
|
||||
id: 'A'
|
||||
},
|
||||
{
|
||||
id: 'B'
|
||||
},
|
||||
{
|
||||
id: 'C'
|
||||
},
|
||||
{
|
||||
id: 'D'
|
||||
},
|
||||
{
|
||||
id: 'E'
|
||||
},
|
||||
{
|
||||
id: 'F'
|
||||
},
|
||||
{
|
||||
id: 'G'
|
||||
},
|
||||
],
|
||||
edges: [
|
||||
{
|
||||
source: 'A',
|
||||
target: 'B',
|
||||
weight: 1,
|
||||
},
|
||||
{
|
||||
source: 'B',
|
||||
target: 'C',
|
||||
weight: 1,
|
||||
},
|
||||
{
|
||||
source: 'A',
|
||||
target: 'C',
|
||||
weight: 2,
|
||||
},
|
||||
{
|
||||
source: 'D',
|
||||
target: 'A',
|
||||
weight: 3,
|
||||
},
|
||||
{
|
||||
source: 'D',
|
||||
target: 'E',
|
||||
weight: 4,
|
||||
},
|
||||
{
|
||||
source: 'E',
|
||||
target: 'F',
|
||||
weight: 2,
|
||||
},
|
||||
{
|
||||
source: 'F',
|
||||
target: 'D',
|
||||
weight: 3,
|
||||
}
|
||||
]
|
||||
}
|
||||
data.nodes.forEach(node => node['label'] = node.id)
|
||||
data.edges.forEach(edge => edge['label'] = edge.weight)
|
||||
describe('minimumSpanningTree', () => {
|
||||
const graph = new G6.Graph({
|
||||
container: 'container',
|
||||
width: 500,
|
||||
height: 500,
|
||||
layout: {
|
||||
type: 'force'
|
||||
},
|
||||
modes: {
|
||||
default: ['drag-node']
|
||||
},
|
||||
defaultNode: {
|
||||
labelCfg: {
|
||||
style: {
|
||||
fontSize: 12,
|
||||
},
|
||||
},
|
||||
},
|
||||
defaultEdge: {
|
||||
style: {
|
||||
endArrow: true,
|
||||
},
|
||||
labelCfg: {
|
||||
style: {
|
||||
fontSize: 12,
|
||||
},
|
||||
},
|
||||
},
|
||||
edgeStateStyles: {
|
||||
mst: {
|
||||
stroke: 'red'
|
||||
}
|
||||
}
|
||||
})
|
||||
|
||||
graph.data(data)
|
||||
graph.render()
|
||||
|
||||
it('test kruskal algorithm', () => {
|
||||
let result = minimumSpanningTree(graph, 'weight')
|
||||
let totalWeight = 0
|
||||
for (let edge of result) {
|
||||
graph.setItemState(edge, 'mst', true);
|
||||
totalWeight += edge.getModel()['weight']
|
||||
}
|
||||
expect(totalWeight).toEqual(10);
|
||||
});
|
||||
|
||||
it('test prim algorithm', () => {
|
||||
let result = minimumSpanningTree(graph, 'weight', 'prim')
|
||||
let totalWeight = 0
|
||||
for (let edge of result) {
|
||||
graph.setItemState(edge, 'mst', true);
|
||||
totalWeight += edge.getModel()['weight']
|
||||
}
|
||||
expect(totalWeight).toEqual(10);
|
||||
});
|
||||
|
||||
});
|
Loading…
Reference in New Issue
Block a user