This JavaScript program demonstrates how to generate the nodes of a classification tree and graph it with a recursive method.
<!DOCTYPE html>
<html lang="en">
<head>
  <meta charset="utf-8">
  <title>XoaX.net's Javascript</title>
  <script type="text/javascript" src="ClassificationTree.js"></script>
</head>
<body onload="Initialize()">
  <!-- Drawing surface; the script finds it by its id "idCanvas". -->
  <canvas id="idCanvas" width="600" height="600" style="background-color: #F0F0F0;"></canvas>
</body>
</html>
// This is a point in the region [0,1]x[0,1] representing two variable values.
// Additionally, the point has a classification value of 0 or 1.
class CDataPoint {
  #mdaValues = [0, 0];
  #miClass;

  /**
   * @param {number} dX0 - Value of the first variable.
   * @param {number} dX1 - Value of the second variable.
   * @param {number} iClass - Classification value, 0 or 1.
   */
  constructor(dX0, dX1, iClass) {
    this.#mdaValues[0] = dX0;
    this.#mdaValues[1] = dX1;
    this.#miClass = iClass;
  }

  // Both variable values as the array [x0, x1].
  get mdX() { return this.#mdaValues; }

  get miClass() { return this.#miClass; }

  // Returns a point with both coordinates from Math.random() and class 0 or 1 with equal probability.
  static Uniform() {
    return new CDataPoint(Math.random(), Math.random(), ((Math.random() < .5) ? 0 : 1));
  }
}

// A collection of CDataPoint objects, with helpers to draw them and to find the
// single axis-aligned split that minimizes the weighted class entropy.
class CDataSet {
  #mqaPoints;

  /** @param {CDataPoint[]} qaPoints - The points of this set (stored, not copied). */
  constructor(qaPoints) {
    this.#mqaPoints = qaPoints;
  }

  // Draws every point, scaling [0,1] coordinates to the canvas size.
  // The points are drawn as red circles (class 0) or blue squares (class 1).
  Graph(qContext) {
    const iWidth = qContext.canvas.width;
    const iHeight = qContext.canvas.height;
    const kdRadius = 3;
    const kdSide = 6;
    for (const qP of this.#mqaPoints) {
      const dX = qP.mdX[0] * iWidth;
      const dY = qP.mdX[1] * iHeight;
      qContext.beginPath();
      if (qP.miClass === 0) {
        qContext.fillStyle = "red";
        qContext.ellipse(dX, dY, kdRadius, kdRadius, 0, 0, 2 * Math.PI);
      } else {
        qContext.fillStyle = "blue";
        qContext.rect(dX - kdSide / 2, dY - kdSide / 2, kdSide, kdSide);
      }
      qContext.fill();
    }
  }

  Count() { return this.#mqaPoints.length; }

  Points() { return this.#mqaPoints; }

  P(i) { return this.#mqaPoints[i]; }

  // Writes the number of class-0 points to iaClassSizes[0] and of all other points
  // to iaClassSizes[1], then returns the base-2 entropy of that distribution.
  Entropy(iaClassSizes) {
    iaClassSizes[0] = 0;
    iaClassSizes[1] = 0;
    for (const qP of this.#mqaPoints) {
      ++iaClassSizes[qP.miClass === 0 ? 0 : 1];
    }
    // Calls the free function Entropy below, not this method.
    return Entropy(iaClassSizes[0], iaClassSizes[1]);
  }

  // Returns a new array of the points sorted ascending by coordinate iCoord.
  // Array.prototype.sort is stable, so points with equal values keep their set order.
  CreateSortedList(iCoord) {
    return [...this.#mqaPoints].sort((qA, qB) => qA.mdX[iCoord] - qB.mdX[iCoord]);
  }

  // Finds the axis-aligned split that minimizes the weighted entropy of the two
  // resulting groups. Returns a CSplit whose low set holds the points below the
  // boundary value and whose high set holds the points above it. Returns null
  // when no split lowers the entropy of the whole set (e.g. the set is pure or
  // has fewer than two points).
  FindOptimalSplit() {
    const iCount = this.#mqaPoints.length;
    // There are two classes: 0 and 1.
    const iaTotals = [0, 0];
    // Start with the entropy of the whole group.
    let dMinEntropy = this.Entropy(iaTotals);
    let iBestCoord = 0;
    // Index of the last sorted point in the low group of the best split; -1 = none found.
    let iBestSplit = -1;
    const qaaSorts = [this.CreateSortedList(0), this.CreateSortedList(1)];

    for (let iCoord = 0; iCoord < 2; ++iCoord) {
      const qaSorted = qaaSorts[iCoord];
      // The counts are [group][class]; everything begins in the high group (1).
      const iaaCounts = [[0, 0], [iaTotals[0], iaTotals[1]]];
      for (let iSplit = 0; iSplit < iCount - 1; ++iSplit) {
        // Move the next point from the high group to the low group.
        const iClass = qaSorted[iSplit].miClass === 0 ? 0 : 1;
        ++iaaCounts[0][iClass];
        --iaaCounts[1][iClass];
        // A threshold cannot separate points that share the same coordinate value.
        if (qaSorted[iSplit].mdX[iCoord] === qaSorted[iSplit + 1].mdX[iCoord]) {
          continue;
        }
        const dCurrEntropy = TotalSplitEntropy(iaaCounts);
        if (dCurrEntropy < dMinEntropy) {
          iBestCoord = iCoord;
          iBestSplit = iSplit;
          dMinEntropy = dCurrEntropy;
        }
      }
    }

    if (iBestSplit < 0) {
      return null;
    }
    const qaBest = qaaSorts[iBestCoord];
    // Use the average of the last low value and the first high value.
    const dValue = (qaBest[iBestSplit].mdX[iBestCoord] + qaBest[iBestSplit + 1].mdX[iBestCoord]) / 2;
    return new CSplit(
      iBestCoord,
      dValue,
      new CDataSet(qaBest.slice(0, iBestSplit + 1)),
      new CDataSet(qaBest.slice(iBestSplit + 1))
    );
  }

  // Returns an array of iCount points created by CDataPoint.Uniform().
  static GenerateUniformSet(iCount) {
    return Array.from({ length: iCount }, () => CDataPoint.Uniform());
  }
}

// This is a decision boundary split for the classification tree.
// A split contains a variable coordinate index of 0 or 1 for the first or second variable.
// It also contains a boundary value for that coordinate in the interval [0, 1],
// and the data sets that fall on the low and high sides of that boundary.
class CSplit {
  #miCoord;
  #mdValue;
  #mqLowSet;
  #mqHighSet;

  constructor(iCoord, dValue, qLowSet, qHighSet) {
    this.#miCoord = iCoord;
    this.#mdValue = dValue;
    this.#mqLowSet = qLowSet;
    this.#mqHighSet = qHighSet;
  }

  get miC() { return this.#miCoord; }

  get mdV() { return this.#mdValue; }

  get mqLowSet() { return this.#mqLowSet; }

  get mqHighSet() { return this.#mqHighSet; }
}

// A decision node contains the optimal split of its data set and up to two child
// nodes, one per side of the split. A side gets a child node only when it holds
// more than iMinSplitSize points. A node whose set has no entropy-reducing split
// stores no split and draws nothing.
class CDecisionNode {
  #mqSplit;
  #mqLowNode = null;
  #mqHighNode = null;

  /**
   * @param {CDataSet} qDataSet - The points this node partitions.
   * @param {number} [iMinSplitSize=20] - A side needs more than this many points to get a child node.
   */
  constructor(qDataSet, iMinSplitSize = 20) {
    this.#mqSplit = qDataSet.FindOptimalSplit();
    if (this.#mqSplit === null) {
      return;
    }
    if (this.#mqSplit.mqLowSet.Count() > iMinSplitSize) {
      this.#mqLowNode = new CDecisionNode(this.#mqSplit.mqLowSet, iMinSplitSize);
    }
    if (this.#mqSplit.mqHighSet.Count() > iMinSplitSize) {
      this.#mqHighNode = new CDecisionNode(this.#mqSplit.mqHighSet, iMinSplitSize);
    }
  }

  // Coordinate index of this node's split, or undefined when the node has no split.
  get miC() { return this.#mqSplit?.miC; }

  // Boundary value of this node's split, or undefined when the node has no split.
  get mdV() { return this.#mqSplit?.mdV; }

  // Draws this node's boundary across qRect (a CRectangle in [0,1] coordinates)
  // with opacity dAlpha, then recurses into the children at 3/4 of that opacity.
  Graph(qContext, qRect, dAlpha) {
    if (this.#mqSplit === null) {
      return;
    }
    const iWidth = qContext.canvas.width;
    const iHeight = qContext.canvas.height;
    qContext.strokeStyle = `rgb(0 0 0 / ${dAlpha})`;
    qContext.lineWidth = 1;
    qContext.beginPath();
    if (this.miC === 0) {
      const dX = this.mdV * iWidth;
      qContext.moveTo(dX, qRect.mdLowY * iHeight);
      qContext.lineTo(dX, qRect.mdHighY * iHeight);
    } else {
      const dY = this.mdV * iHeight;
      qContext.moveTo(qRect.mdLowX * iWidth, dY);
      qContext.lineTo(qRect.mdHighX * iWidth, dY);
    }
    qContext.stroke();
    const [qLowRect, qHighRect] = qRect.Split(this.#mqSplit);
    const dChildAlpha = 3 * dAlpha / 4;
    this.#mqLowNode?.Graph(qContext, qLowRect, dChildAlpha);
    this.#mqHighNode?.Graph(qContext, qHighRect, dChildAlpha);
  }
}

// This represents a region of space for the decision node that is used for rendering the boundary.
class CRectangle {
  #mdaLow = [0, 0];
  #mdaHigh = [0, 0];

  // Copies the corner values, so later changes to daLow/daHigh do not affect this rectangle.
  constructor(daLow, daHigh) {
    this.#mdaLow[0] = daLow[0];
    this.#mdaLow[1] = daLow[1];
    this.#mdaHigh[0] = daHigh[0];
    this.#mdaHigh[1] = daHigh[1];
  }

  // Cuts this rectangle at qSplit's boundary on coordinate qSplit.miC.
  // Returns [lowSide, highSide].
  Split(qSplit) {
    const iC = qSplit.miC;
    const daLowSideHigh = [this.#mdaHigh[0], this.#mdaHigh[1]];
    daLowSideHigh[iC] = qSplit.mdV;
    const daHighSideLow = [this.#mdaLow[0], this.#mdaLow[1]];
    daHighSideLow[iC] = qSplit.mdV;
    return [
      new CRectangle(this.#mdaLow, daLowSideHigh),
      new CRectangle(daHighSideLow, this.#mdaHigh),
    ];
  }

  get mdLowX() { return this.#mdaLow[0]; }

  get mdLowY() { return this.#mdaLow[1]; }

  get mdHighX() { return this.#mdaHigh[0]; }

  get mdHighY() { return this.#mdaHigh[1]; }
}

// Base-2 entropy of a two-class distribution with the given counts.
// Returns 0 for an empty or pure distribution (0*log(0) is taken as 0).
function Entropy(iCount0, iCount1) {
  const iTotal = iCount0 + iCount1;
  if (iTotal === 0) {
    return 0;
  }
  let dEntropy = 0;
  for (const iCount of [iCount0, iCount1]) {
    if (iCount > 0) {
      const dP = iCount / iTotal;
      dEntropy -= dP * Math.log2(dP);
    }
  }
  return dEntropy;
}

// Entropy of a two-group split weighted by group size. iaaCounts is [group][class].
// Returns 0 when both groups are empty.
function TotalSplitEntropy(iaaCounts) {
  const iTotal0 = iaaCounts[0][0] + iaaCounts[0][1];
  const iTotal1 = iaaCounts[1][0] + iaaCounts[1][1];
  const iTotal = iTotal0 + iTotal1;
  if (iTotal === 0) {
    return 0;
  }
  const dPG0 = iTotal0 / iTotal;
  const dPG1 = iTotal1 / iTotal;
  return dPG0 * Entropy(iaaCounts[0][0], iaaCounts[0][1]) + dPG1 * Entropy(iaaCounts[1][0], iaaCounts[1][1]);
}

// This is a classification tree.
// The root node specifies the main split with a variable index (0 or 1) and a value in the interval [0, 1]
class CClassificationTree {
  #mqDataSet = null;
  #mqRootNode = null;

  /** @param {number} iCount - Number of random points (CDataSet.GenerateUniformSet) to build the tree from. */
  constructor(iCount) {
    this.#mqDataSet = new CDataSet(CDataSet.GenerateUniformSet(iCount));
    this.#mqRootNode = new CDecisionNode(this.#mqDataSet);
  }

  /**
   * Draws the data points and the decision boundaries on a canvas. The y axis is
   * flipped so that data y = 0 lies along the bottom edge of the canvas.
   * @param {string} [sCanvasId="idCanvas"] - id of the canvas element to draw on.
   */
  Graph(sCanvasId = "idCanvas") {
    const qCanvas = document.getElementById(sCanvasId);
    const qContext2D = qCanvas.getContext("2d");
    // transform() composes with the current matrix. Without save/restore the flip
    // would persist on the shared context, and the next call would undo it.
    qContext2D.save();
    qContext2D.transform(1, 0, 0, -1, 0, qContext2D.canvas.height);
    this.#mqDataSet.Graph(qContext2D);
    this.#mqRootNode.Graph(qContext2D, new CRectangle([0, 0], [1, 1]), 1.0);
    qContext2D.restore();
  }
}

// Page entry point (the body's onload handler): builds a tree over 50 random points and draws it.
function Initialize() {
  const qClassificationTree = new CClassificationTree(50);
  qClassificationTree.Graph();
}
© 2007–2025 XoaX.net LLC. All rights reserved.