feat: add decision tree demo

This commit is contained in:
yvonneyx 2024-08-06 16:48:55 +08:00
parent 7ce727aacd
commit c516b3de04
6 changed files with 721 additions and 3 deletions

View File

@ -0,0 +1,192 @@
{
"id": "g1",
"name": "Name1",
"count": 123456,
"label": "538.90",
"currency": "Yuan",
"rate": 1.0,
"status": "B",
"variableName": "V1",
"variableValue": 0.341,
"variableUp": false,
"children": [
{
"id": "g12",
"name": "Deal with LONG label LONG label LONG label LONG label",
"count": 123456,
"label": "338.00",
"rate": 0.627,
"status": "R",
"currency": "Yuan",
"variableName": "V2",
"variableValue": 0.179,
"variableUp": true,
"children": [
{
"id": "g121",
"name": "Name3",
"collapsed": true,
"count": 123456,
"label": "138.00",
"rate": 0.123,
"status": "B",
"currency": "Yuan",
"variableName": "V2",
"variableValue": 0.27,
"variableUp": true,
"children": [
{
"id": "g1211",
"name": "Name4",
"count": 123456,
"label": "138.00",
"rate": 1.0,
"status": "B",
"currency": "Yuan",
"variableName": "V1",
"variableValue": 0.164,
"variableUp": false,
"children": []
}
]
},
{
"id": "g122",
"name": "Name5",
"collapsed": true,
"count": 123456,
"label": "100.00",
"rate": 0.296,
"status": "G",
"currency": "Yuan",
"variableName": "V1",
"variableValue": 0.259,
"variableUp": true,
"children": [
{
"id": "g1221",
"name": "Name6",
"count": 123456,
"label": "40.00",
"rate": 0.4,
"status": "G",
"currency": "Yuan",
"variableName": "V1",
"variableValue": 0.135,
"variableUp": true,
"children": [
{
"id": "g12211",
"name": "Name6-1",
"count": 123456,
"label": "40.00",
"rate": 1.0,
"status": "R",
"currency": "Yuan",
"variableName": "V1",
"variableValue": 0.181,
"variableUp": true,
"children": []
}
]
},
{
"id": "g1222",
"name": "Name7",
"count": 123456,
"label": "60.00",
"rate": 0.6,
"status": "G",
"currency": "Yuan",
"variableName": "V1",
"variableValue": 0.239,
"variableUp": false,
"children": []
}
]
},
{
"id": "g123",
"name": "Name8",
"collapsed": true,
"count": 123456,
"label": "100.00",
"rate": 0.296,
"status": "DI",
"currency": "Yuan",
"variableName": "V2",
"variableValue": 0.131,
"variableUp": false,
"children": [
{
"id": "g1231",
"name": "Name8-1",
"count": 123456,
"label": "100.00",
"rate": 1.0,
"status": "DI",
"currency": "Yuan",
"variableName": "V2",
"variableValue": 0.131,
"variableUp": false,
"children": []
}
]
}
]
},
{
"id": "g13",
"name": "Name9",
"count": 123456,
"label": "100.90",
"rate": 0.187,
"status": "B",
"currency": "Yuan",
"variableName": "V2",
"variableValue": 0.221,
"variableUp": true,
"children": [
{
"id": "g131",
"name": "Name10",
"count": 123456,
"label": "33.90",
"rate": 0.336,
"status": "R",
"currency": "Yuan",
"variableName": "V1",
"variableValue": 0.12,
"variableUp": true,
"children": []
},
{
"id": "g132",
"name": "Name11",
"count": 123456,
"label": "67.00",
"rate": 0.664,
"status": "G",
"currency": "Yuan",
"variableName": "V1",
"variableValue": 0.241,
"variableUp": false,
"children": []
}
]
},
{
"id": "g14",
"name": "Name12",
"count": 123456,
"label": "100.00",
"rate": 0.186,
"status": "G",
"currency": "Yuan",
"variableName": "V2",
"variableValue": 0.531,
"variableUp": true,
"children": []
}
]
}

View File

@ -0,0 +1,257 @@
import data from '@@/dataset/decision-tree.json';
import type { DisplayObject, RectStyleProps as GRectStyleProps, TextStyleProps as GTextStyleProps } from '@antv/g';
import { Rect as GRect, Group, Text as GText } from '@antv/g';
import type { BadgeStyleProps, LabelStyleProps, NodeData, RectStyleProps } from '@antv/g6';
import {
Badge,
CommonEvent,
ExtensionCategory,
Graph,
GraphEvent,
Label,
Rect,
register,
treeToGraphData,
} from '@antv/g6';
import { TreeData } from '../../src/types';
export const caseDecisionTree: TestCase = async (context) => {
const COLORS: Record<string, string> = {
B: '#1783FF',
R: '#F46649',
Y: '#DB9D0D',
G: '#60C42D',
DI: '#A7A7A7',
};
const GREY_COLOR = '#CED4D9';
const NODE_HEIGHT = 60;
const NODE_WIDTH = 202;
const NODE_RADIUS = 4;
class TreeNode extends Rect {
get data() {
return this.context.model.getNodeData([this.id])[0] as Record<string, string>;
}
get childrenData() {
return this.context.model.getChildrenData(this.id);
}
protected getLabelStyle(attributes: Required<RectStyleProps>): LabelStyleProps {
return {
text: this.data.name,
fontSize: 12,
opacity: 0.85,
fill: '#000',
cursor: 'pointer',
};
}
protected getPriceStyle(attributes: Required<RectStyleProps>): GTextStyleProps {
return {
y: NODE_HEIGHT - 24,
text: this.data.label,
fontSize: 16,
fill: '#000',
opacity: 0.85,
};
}
protected drawPriceShape(attributes: Required<RectStyleProps>, container: Group) {
const priceStyle = this.getPriceStyle(attributes);
this.upsert('price', GText, priceStyle, container);
}
protected getCurrencyStyle(attributes: Required<RectStyleProps>): GTextStyleProps {
return {
x: this.shapeMap['price'].getLocalBounds().max[0] + 4,
y: NODE_HEIGHT - 24,
text: this.data.currency,
fontSize: 12,
fill: '#000',
opacity: 0.75,
};
}
protected drawCurrencyShape(attributes: Required<RectStyleProps>, container: Group) {
const currencyStyle = this.getCurrencyStyle(attributes);
this.upsert('currency', GText, currencyStyle, container);
}
protected getPercentStyle(attributes: Required<RectStyleProps>): GTextStyleProps {
return {
x: NODE_WIDTH - 24,
y: NODE_HEIGHT - 24,
text: `${((Number(this.data.variableValue) || 0) * 100).toFixed(2)}%`,
fontSize: 12,
textAlign: 'right',
fill: COLORS[this.data.status],
};
}
protected drawPercentShape(attributes: Required<RectStyleProps>, container: Group) {
const percentStyle = this.getPercentStyle(attributes);
this.upsert('percent', GText, percentStyle, container);
}
protected getTriangleStyle(attributes: Required<RectStyleProps>): LabelStyleProps {
const percentMinX = this.shapeMap['percent'].getLocalBounds().min[0];
return {
fill: COLORS[this.data.status],
x: this.data.variableUp ? percentMinX - 18 : percentMinX,
y: NODE_HEIGHT - 32,
fontFamily: 'iconfont',
fontSize: 16,
text: '\ue62d',
transform: this.data.variableUp ? '' : 'rotate(180deg)',
};
}
protected drawTriangleShape(attributes: Required<RectStyleProps>, container: Group) {
const triangleStyle = this.getTriangleStyle(attributes);
this.upsert('triangle', Label, triangleStyle, container);
}
protected getVariableStyle(attributes: Required<RectStyleProps>): GTextStyleProps {
return {
fill: '#000',
fontSize: 12,
opacity: 0.45,
text: this.data.variableName,
textAlign: 'right',
x: this.shapeMap['triangle'].getLocalBounds().min[0] - 4,
y: NODE_HEIGHT - 24,
};
}
protected drawVariableShape(attributes: Required<RectStyleProps>, container: Group) {
const variableStyle = this.getVariableStyle(attributes);
this.upsert('variable', GText, variableStyle, container);
}
protected getCollapseStyle(attributes: Required<RectStyleProps>): BadgeStyleProps | false {
if (this.childrenData.length === 0) return false;
const { collapsed } = attributes;
return {
backgroundFill: '#fff',
backgroundHeight: 16,
backgroundLineWidth: 1,
backgroundRadius: 0,
backgroundStroke: GREY_COLOR,
backgroundWidth: 16,
cursor: 'pointer',
fill: GREY_COLOR,
fontSize: 16,
text: collapsed ? '+' : '-',
textAlign: 'center',
textBaseline: 'middle',
x: NODE_WIDTH - 16,
y: NODE_HEIGHT / 2 - 16,
};
}
protected drawCollapseShape(attributes: Required<RectStyleProps>, container: Group) {
const collapseStyle = this.getCollapseStyle(attributes);
const btn = this.upsert('collapse', Badge, collapseStyle, container);
this.forwardEvent(btn, CommonEvent.CLICK, () => {
const { collapsed } = this.attributes;
const graph = this.context.graph;
if (collapsed) graph.expandElement(this.id);
else graph.collapseElement(this.id);
});
}
private forwardEvent(target: DisplayObject | undefined, type: string, listener: (event: any) => void) {
if (target && !Reflect.has(target, '__bind__')) {
Reflect.set(target, '__bind__', true);
target.addEventListener(type, listener);
}
}
protected getProcessBarStyle(attributes: Required<RectStyleProps>): GRectStyleProps {
const { rate, status } = this.data;
const color = COLORS[status];
const percent = `${rate * 100}%`;
return {
x: -16,
y: NODE_HEIGHT - 20,
width: NODE_WIDTH,
height: 4,
radius: [0, 0, NODE_RADIUS, NODE_RADIUS],
fill: `linear-gradient(to right, ${color} ${percent}, ${GREY_COLOR} ${percent})`,
};
}
protected drawProcessBarShape(attributes: Required<RectStyleProps>, container: Group) {
const processBarStyle = this.getProcessBarStyle(attributes);
this.upsert('process-bar', GRect, processBarStyle, container);
}
protected getKeyStyle(attributes: Required<RectStyleProps>): GRectStyleProps {
const keyStyle = super.getKeyStyle(attributes);
return {
...keyStyle,
fill: '#fff',
height: NODE_HEIGHT,
width: NODE_WIDTH,
lineWidth: 1,
radius: NODE_RADIUS,
stroke: GREY_COLOR,
};
}
public render(attributes: Required<RectStyleProps> = this.parsedAttributes, container: Group) {
super.render(attributes, container);
this.drawPriceShape(attributes, container);
this.drawCurrencyShape(attributes, container);
this.drawPercentShape(attributes, container);
this.drawTriangleShape(attributes, container);
this.drawVariableShape(attributes, container);
this.drawProcessBarShape(attributes, container);
this.drawCollapseShape(attributes, container);
}
}
register(ExtensionCategory.NODE, 'tree-node', TreeNode);
const graph = new Graph({
...context,
data: treeToGraphData(data, {
getNodeData: (datum: TreeData, depth: number) => {
if (!datum.style) datum.style = {};
datum.style.collapsed = depth >= 2;
if (!datum.children) return datum as NodeData;
const { children, ...restDatum } = datum;
return { ...restDatum, children: children.map((child) => child.id) } as NodeData;
},
}),
node: {
type: 'tree-node',
style: { ports: [{ placement: 'left' }, { placement: 'right' }] },
},
edge: {
type: 'cubic-horizontal',
style: {
stroke: GREY_COLOR,
},
},
layout: {
type: 'indented',
direction: 'LR',
dropCap: false,
indent: NODE_WIDTH + 100,
getHeight: () => NODE_HEIGHT,
},
behaviors: ['zoom-canvas', 'drag-canvas'],
});
graph.once(GraphEvent.AFTER_RENDER, () => {
graph.fitView();
});
await graph.render();
return graph;
};

View File

@ -19,6 +19,7 @@ export { behaviorLassoSelect } from './behavior-lasso-select';
export { behaviorOptimizeViewportTransform } from './behavior-optimize-viewport-transform';
export { behaviorScrollCanvas } from './behavior-scroll-canvas';
export { behaviorZoomCanvas } from './behavior-zoom-canvas';
export { caseDecisionTree } from './case-decision-tree';
export { caseIndentedTree } from './case-indented-tree';
export { caseOrgChart } from './case-org-chart';
export { commonGraph } from './common-graph';

View File

@ -3,7 +3,7 @@ import type { TreeData } from '../types';
import { dfs } from './traverse';
type TreeDataGetter = {
getNodeData?: (datum: TreeData) => NodeData;
getNodeData?: (datum: TreeData, depth: number) => NodeData;
getEdgeData?: (source: TreeData, target: TreeData) => EdgeData;
getChildren?: (datum: TreeData) => TreeData[];
};
@ -32,8 +32,8 @@ export function treeToGraphData(treeData: TreeData, getter?: TreeDataGetter): Gr
dfs(
treeData,
(node) => {
nodes.push(getNodeData(node));
(node, depth) => {
nodes.push(getNodeData(node, depth));
const children = getChildren(node);
for (const child of children) {
edges.push(getEdgeData(node, child));

View File

@ -0,0 +1,260 @@
import { Rect as GRect, Text as GText } from '@antv/g';
import {
Badge,
CommonEvent,
ExtensionCategory,
Graph,
GraphEvent,
iconfont,
Label,
Rect,
register,
treeToGraphData,
} from '@antv/g6';
const style = document.createElement('style');
style.innerHTML = `@import url('${iconfont.css}');`;
document.head.appendChild(style);
const COLORS = {
B: '#1783FF',
R: '#F46649',
Y: '#DB9D0D',
G: '#60C42D',
DI: '#A7A7A7',
};
const GREY_COLOR = '#CED4D9';
const NODE_HEIGHT = 60;
const NODE_WIDTH = 202;
const NODE_RADIUS = 4;
class TreeNode extends Rect {
get data() {
return this.context.model.getNodeData([this.id])[0];
}
get childrenData() {
return this.context.model.getChildrenData(this.id);
}
getLabelStyle() {
return {
text: this.data.name,
fontSize: 12,
opacity: 0.85,
fill: '#000',
cursor: 'pointer',
};
}
getPriceStyle() {
return {
y: NODE_HEIGHT - 24,
text: this.data.label,
fontSize: 16,
fill: '#000',
opacity: 0.85,
};
}
drawPriceShape(attributes, container) {
const priceStyle = this.getPriceStyle(attributes);
this.upsert('price', GText, priceStyle, container);
}
getCurrencyStyle() {
return {
x: this.shapeMap['price'].getLocalBounds().max[0] + 4,
y: NODE_HEIGHT - 24,
text: this.data.currency,
fontSize: 12,
fill: '#000',
opacity: 0.75,
};
}
drawCurrencyShape(attributes, container) {
const currencyStyle = this.getCurrencyStyle(attributes);
this.upsert('currency', GText, currencyStyle, container);
}
getPercentStyle() {
return {
x: NODE_WIDTH - 24,
y: NODE_HEIGHT - 24,
text: `${((Number(this.data.variableValue) || 0) * 100).toFixed(2)}%`,
fontSize: 12,
textAlign: 'right',
fill: COLORS[this.data.status],
};
}
drawPercentShape(attributes, container) {
const percentStyle = this.getPercentStyle(attributes);
this.upsert('percent', GText, percentStyle, container);
}
getTriangleStyle() {
const percentMinX = this.shapeMap['percent'].getLocalBounds().min[0];
return {
fill: COLORS[this.data.status],
x: this.data.variableUp ? percentMinX - 18 : percentMinX,
y: NODE_HEIGHT - 32,
fontFamily: 'iconfont',
fontSize: 16,
text: '\ue62d',
transform: this.data.variableUp ? '' : 'rotate(180deg)',
};
}
drawTriangleShape(attributes, container) {
const triangleStyle = this.getTriangleStyle(attributes);
this.upsert('triangle', Label, triangleStyle, container);
}
getVariableStyle() {
return {
fill: '#000',
fontSize: 12,
opacity: 0.45,
text: this.data.variableName,
textAlign: 'right',
x: this.shapeMap['triangle'].getLocalBounds().min[0] - 4,
y: NODE_HEIGHT - 24,
};
}
drawVariableShape(attributes, container) {
const variableStyle = this.getVariableStyle(attributes);
this.upsert('variable', GText, variableStyle, container);
}
getCollapseStyle(attributes) {
if (this.childrenData.length === 0) return false;
const { collapsed } = attributes;
return {
backgroundFill: '#fff',
backgroundHeight: 16,
backgroundLineWidth: 1,
backgroundRadius: 0,
backgroundStroke: GREY_COLOR,
backgroundWidth: 16,
cursor: 'pointer',
fill: GREY_COLOR,
fontSize: 16,
text: collapsed ? '+' : '-',
textAlign: 'center',
textBaseline: 'middle',
x: NODE_WIDTH - 16,
y: NODE_HEIGHT / 2 - 16,
};
}
drawCollapseShape(attributes, container) {
const collapseStyle = this.getCollapseStyle(attributes);
const btn = this.upsert('collapse', Badge, collapseStyle, container);
this.forwardEvent(btn, CommonEvent.CLICK, () => {
const { collapsed } = this.attributes;
const graph = this.context.graph;
if (collapsed) graph.expandElement(this.id);
else graph.collapseElement(this.id);
});
}
forwardEvent(target, type, listener) {
if (target && !Reflect.has(target, '__bind__')) {
Reflect.set(target, '__bind__', true);
target.addEventListener(type, listener);
}
}
getProcessBarStyle() {
const { rate, status } = this.data;
const color = COLORS[status];
const percent = `${rate * 100}%`;
return {
x: -16,
y: NODE_HEIGHT - 20,
width: NODE_WIDTH,
height: 4,
radius: [0, 0, NODE_RADIUS, NODE_RADIUS],
fill: `linear-gradient(to right, ${color} ${percent}, ${GREY_COLOR} ${percent})`,
};
}
drawProcessBarShape(attributes, container) {
const processBarStyle = this.getProcessBarStyle(attributes);
this.upsert('process-bar', GRect, processBarStyle, container);
}
getKeyStyle(attributes) {
const keyStyle = super.getKeyStyle(attributes);
return {
...keyStyle,
fill: '#fff',
height: NODE_HEIGHT,
width: NODE_WIDTH,
lineWidth: 1,
radius: NODE_RADIUS,
stroke: GREY_COLOR,
};
}
render(attributes = this.parsedAttributes, container) {
super.render(attributes, container);
this.drawPriceShape(attributes, container);
this.drawCurrencyShape(attributes, container);
this.drawPercentShape(attributes, container);
this.drawTriangleShape(attributes, container);
this.drawVariableShape(attributes, container);
this.drawProcessBarShape(attributes, container);
this.drawCollapseShape(attributes, container);
}
}
register(ExtensionCategory.NODE, 'tree-node', TreeNode);
fetch('https://assets.antv.antgroup.com/g6/decision-tree.json')
.then((res) => res.json())
.then((data) => {
const graph = new Graph({
container: 'container',
data: treeToGraphData(data, {
getNodeData: (datum, depth) => {
if (!datum.style) datum.style = {};
// 层级大于 1 的节点默认收起
// Nodes with a depth greater than 2 are collapsed by default
datum.style.collapsed = depth >= 2;
if (!datum.children) return datum;
const { children, ...restDatum } = datum;
return { ...restDatum, children: children.map((child) => child.id) };
},
}),
node: {
type: 'tree-node',
style: { ports: [{ placement: 'left' }, { placement: 'right' }] },
},
edge: {
type: 'cubic-horizontal',
style: {
stroke: GREY_COLOR,
},
},
layout: {
type: 'indented',
direction: 'LR',
dropCap: false,
indent: NODE_WIDTH + 100,
getHeight: () => NODE_HEIGHT,
},
behaviors: ['zoom-canvas', 'drag-canvas'],
});
graph.once(GraphEvent.AFTER_RENDER, () => {
graph.fitView();
});
graph.render();
});

View File

@ -28,6 +28,14 @@
},
"screenshot": "https://mdn.alipayobjects.com/huamei_qa8qxu/afts/img/A*r0-SS5dRxykAAAAAAAAAAAAADmJ7AQ/original"
},
{
"filename": "decision-tree.js",
"title": {
"zh": "决策树",
"en": "Decision Tree"
},
"screenshot": "https://mdn.alipayobjects.com/huamei_qa8qxu/afts/img/A*ImBoQIveCtYAAAAAAAAAAAAADmJ7AQ/original"
},
{
"filename": "organization-chart.js",
"title": {