import { ethers } from 'hardhat';
import { SumtreeLibrary, TestSumtree } from '../src/contracts';
import { randomBytes, hexlify, getAddress, AddressLike } from 'ethers';
import { expect } from 'chai';

describe('Sumtree', () => {
    let tst: TestSumtree;
    let sumtreeLibrary: SumtreeLibrary;

    before(async () => {
        const stf = await ethers.getContractFactory('SumtreeLibrary');
        sumtreeLibrary = await stf.deploy();
        await sumtreeLibrary.waitForDeployment();
    });

    beforeEach(async () => {
        // Every test gets a fresh TestSumtree, linked against the library deployed once in `before`.
        const f = await ethers.getContractFactory('TestSumtree', {
            libraries: { 'SumtreeLibrary': await sumtreeLibrary.getAddress() },
        });
        tst = await f.deploy();
        await tst.waitForDeployment();
    });

    /** A (value, key) pair as read back from a tree node. */
    interface NodeValue {
        value: bigint;
        key: string;
    }

    /**
     * Outcome of validating a subtree. When valid, carries the sum of values and the
     * node count recomputed from the leaves up, so callers need no non-null assertions.
     */
    type ValidationResult =
        | { isValid: true; actualSum: bigint; actualCount: bigint }
        | { isValid: false };

    /** Checksummed address made of `digit` repeated 40 times, e.g. addr('1') === 0x1111…1111. */
    function addr(digit: string): string {
        return getAddress('0x' + digit.repeat(40));
    }

    /** Random checksummed 20-byte address. */
    function randomAddress(): string {
        return getAddress(hexlify(randomBytes(20)));
    }

    /** Random integer weight in [1, 1000]. */
    function randomWeight(): number {
        return 1 + Math.floor(Math.random() * 1000);
    }

    /**
     * Strict tree ordering: by value first, ties broken by the key compared numerically
     * as an address (so checksum casing does not matter).
     */
    function isBefore(a: NodeValue, b: NodeValue): boolean {
        return a.value < b.value || (a.value === b.value && BigInt(a.key) < BigInt(b.key));
    }

    /** Debug helper: prints the subtree rooted at `nodeId`, children indented beneath parents. */
    async function printNode(nodeId: bigint, indent: number = 0, prefix: string = ''): Promise<void> {
        const n = await tst.node(nodeId);
        console.log(' '.repeat(indent), prefix, `id=${nodeId} s=${n.sum} v=${n.value} c=${n.count} k=${n.key} p=${n.parent}`);
        if (n.left !== 0n) {
            await printNode(n.left, indent + 2, 'L');
        }
        if (n.right !== 0n) {
            await printNode(n.right, indent + 2, 'R');
        }
    }

    /**
     * Recursively validates the subtree rooted at `nodeId` (0n = empty subtree):
     *  - every node lies strictly between the bounds inherited from its ancestors
     *    (full BST property, not just parent/child ordering),
     *  - each child's `parent` pointer refers back to its parent,
     *  - each node's stored `sum` and `count` match the values recomputed from its subtree.
     * Logs the first violation found and returns `{ isValid: false }`.
     */
    async function validateNode(nodeId: bigint, lower?: NodeValue, upper?: NodeValue): Promise<ValidationResult> {
        if (nodeId === 0n) {
            return { isValid: true, actualSum: 0n, actualCount: 0n };
        }

        const node = await tst.node(nodeId);
        const current: NodeValue = { value: node.value, key: node.key };

        // A node that is correctly ordered against its parent can still violate a
        // grand-ancestor's bound; checking inherited bounds catches that case.
        if ((lower !== undefined && !isBefore(lower, current)) || (upper !== undefined && !isBefore(current, upper))) {
            console.error(`Ancestor bound violation at node ${nodeId}: (${node.value.toString()}, ${node.key})`);
            return { isValid: false };
        }

        // Validate left child
        if (node.left !== 0n) {
            const leftNode = await tst.node(node.left);
            if (!isBefore({ value: leftNode.value, key: leftNode.key }, current)) {
                console.error(`Order violation at node ${nodeId} with left child ${node.left}`);
                console.error(`Parent: (${node.value.toString()}, ${node.key})`);
                console.error(`Left: (${leftNode.value.toString()}, ${leftNode.key})`);
                return { isValid: false };
            }
            if (leftNode.parent !== nodeId) {
                console.error(`Parent pointer mismatch: node ${node.left} should point to ${nodeId}`);
                return { isValid: false };
            }
        }

        // Validate right child
        if (node.right !== 0n) {
            const rightNode = await tst.node(node.right);
            if (!isBefore(current, { value: rightNode.value, key: rightNode.key })) {
                console.error(`Order violation at node ${nodeId} with right child ${node.right}`);
                console.error(`Parent: (${node.value.toString()}, ${node.key})`);
                console.error(`Right: (${rightNode.value.toString()}, ${rightNode.key})`);
                return { isValid: false };
            }
            if (rightNode.parent !== nodeId) {
                console.error(`Parent pointer mismatch: node ${node.right} should point to ${nodeId}`);
                return { isValid: false };
            }
        }

        // Recurse, tightening the bounds: left subtree < current < right subtree.
        const leftResult = await validateNode(node.left, lower, current);
        if (!leftResult.isValid) return { isValid: false };
        const rightResult = await validateNode(node.right, current, upper);
        if (!rightResult.isValid) return { isValid: false };

        const actualSum = leftResult.actualSum + node.value + rightResult.actualSum;
        const actualCount = leftResult.actualCount + 1n + rightResult.actualCount;

        if (node.sum !== actualSum) {
            console.error(`Sum mismatch at node ${nodeId}:`);
            console.error(`Expected: ${node.sum.toString()}`);
            console.error(`Actual: ${actualSum.toString()}`);
            return { isValid: false };
        }

        if (node.count !== actualCount) {
            console.error(`Count mismatch at node ${nodeId}:`);
            console.error(`Expected: ${node.count}`);
            console.error(`Actual: ${actualCount}`);
            return { isValid: false };
        }

        return { isValid: true, actualSum, actualCount };
    }

    /** Validates the whole tree under `root`; any thrown error is logged and reported as invalid. */
    async function validateTree(root: bigint): Promise<boolean> {
        try {
            return (await validateNode(root)).isValid;
        } catch (error: unknown) {
            // Still fail the assertion, but surface the cause instead of swallowing it.
            console.error('Error validating tree:', error);
            return false;
        }
    }

    /** Reads the tree via in-order traversal, yielding nodes in tree order. */
    async function treeToSortedList(): Promise<NodeValue[]> {
        const values: NodeValue[] = [];

        async function inorderTraversal(nodeId: bigint): Promise<void> {
            if (nodeId === 0n) return;
            const node = await tst.node(nodeId);
            await inorderTraversal(node.left);
            values.push({ value: node.value, key: node.key });
            await inorderTraversal(node.right);
        }

        await inorderTraversal(await tst.root());
        return values;
    }

    /** Checks that `list` is strictly increasing under `isBefore`, collecting a message per violation. */
    function validateOrdering(list: NodeValue[]): { isValid: boolean; violations: string[] } {
        const violations: string[] = [];
        for (let i = 0; i < list.length - 1; i++) {
            const current = list[i];
            const next = list[i + 1];
            if (!isBefore(current, next)) {
                violations.push(
                    `Ordering violation at index ${i}:\n` +
                    `Current: (value: ${current.value.toString()}, key: ${current.key})\n` +
                    `Next: (value: ${next.value.toString()}, key: ${next.key})`
                );
            }
        }
        return { isValid: violations.length === 0, violations };
    }

    /**
     * For each node in in-order `list`, the node owns the cumulative weight range
     * [rangeStart, rangeStart + value). Asserts that `pick` returns that node's key at the
     * first, last and middle point of its range.
     */
    async function verifyPicks(list: NodeValue[]): Promise<void> {
        let total = 0n;
        for (const entry of list) {
            const rangeStart = total;
            total += entry.value;

            const first = await tst.pick(rangeStart);
            const last = await tst.pick(total - 1n);
            expect(first).eq(last);
            expect(last).eq(entry.key);

            const middle = await tst.pick(rangeStart + ((total - 1n - rangeStart) / 2n));
            expect(middle).eq(entry.key);
        }
    }

    /** Adds each (weight, key) pair in order, asserting tree invariants after every insert. */
    async function addAll(pairs: [bigint, AddressLike][]): Promise<void> {
        for (const [w, a] of pairs) {
            const tx = await tst.add(a, w);
            await tx.wait();
            expect(await validateTree(await tst.root())).eq(true);
        }
    }

    /** Removes `key` and asserts tree invariants afterwards. */
    async function removeAndValidate(key: AddressLike): Promise<void> {
        const tx = await tst.remove(key);
        await tx.wait();
        expect(await validateTree(await tst.root())).eq(true);
    }

    it('Simple left rotation', async () => {
        // Insert in increasing value order:
        //   100                200
        //     \               /   \
        //     200     ->    100   300
        //       \
        //       300
        await addAll([
            [100n, addr('1')], // id=1
            [200n, addr('2')], // id=2
            [300n, addr('3')], // id=3
        ]);
        expect(await tst.root()).eq(2n);
        expect((await tst.node(2n)).left).eq(1n);
        expect((await tst.node(2n)).right).eq(3n);
    });

    it('Simple right rotation', async () => {
        // Insert in decreasing value order:
        //       300            200
        //       /             /   \
        //     200     ->    100   300
        //     /
        //   100
        await addAll([
            [300n, addr('4')], // id=1
            [200n, addr('5')], // id=2
            [100n, addr('6')], // id=3
        ]);
        expect(await tst.root()).eq(2n);
        expect((await tst.node(2n)).left).eq(3n);
        expect((await tst.node(2n)).right).eq(1n);
    });

    it('Right-left case (double rotation)', async () => {
        //   100          100                200
        //     \            \               /   \
        //     300   ->     200     ->    100   300
        //     /              \
        //   200              300
        await addAll([
            [100n, addr('7')], // id=1
            [300n, addr('8')], // id=2
            [200n, addr('9')], // id=3
        ]);
        expect(await tst.root()).eq(3n);
        expect((await tst.node(3n)).left).eq(1n);
        expect((await tst.node(3n)).right).eq(2n);
    });

    it('Left-right case (double rotation)', async () => {
        //     300          300            200
        //     /            /             /   \
        //   100    ->    200     ->    100   300
        //     \          /
        //     200      100
        await addAll([
            [300n, addr('7')], // id=1
            [100n, addr('8')], // id=2
            [200n, addr('9')], // id=3
        ]);
        expect(await tst.root()).eq(3n);
        expect((await tst.node(3n)).left).eq(2n);
        expect((await tst.node(3n)).right).eq(1n);
    });

    it('Key-based ordering (same values)', async () => {
        // Equal values: ordering falls back to the key, so this behaves like increasing inserts.
        await addAll([
            [100n, addr('1')], // id=1
            [100n, addr('2')], // id=2
            [100n, addr('3')], // id=3
        ]);
        expect(await tst.root()).eq(2n);
        expect((await tst.node(2n)).left).eq(1n);
        expect((await tst.node(2n)).right).eq(3n);
    });

    it('Remove leaf node from balanced tree', async () => {
        // Initial tree:          After removing 100:
        //     200                    300
        //    /   \                   /
        //  100   300               200
        const pairs: [bigint, AddressLike][] = [
            [200n, addr('1')], // id=1
            [100n, addr('2')], // id=2
            [300n, addr('3')], // id=3
        ];
        await addAll(pairs);
        await removeAndValidate(pairs[1][1]);

        const rootNodeId = await tst.root();
        expect(rootNodeId).eq(3n);
        expect((await tst.node(rootNodeId)).left).eq(1n);
        expect((await tst.node(rootNodeId)).right).eq(0n);
    });

    it('Remove node with one child', async () => {
        // Initial tree:          After removing 100:
        //       200                   200
        //      /   \                 /   \
        //    100   300             50    300
        //    /
        //  50
        const pairs: [bigint, AddressLike][] = [
            [200n, addr('1')], // id=1
            [100n, addr('2')], // id=2
            [300n, addr('3')], // id=3
            [50n, addr('4')],  // id=4
        ];
        await addAll(pairs);
        await removeAndValidate(pairs[1][1]);

        expect(await tst.root()).eq(1n);
        expect((await tst.node(1n)).left).eq(4n);
        expect((await tst.node(1n)).right).eq(3n);
    });

    it('Remove node with two children', async () => {
        // Initial tree:
        //        200
        //       /   \
        //     100   300
        //     / \
        //   50   150
        //
        // Only tree invariants are checked after removing 100; the exact resulting shape is
        // not asserted. The disabled assertions below expected 150 to replace 100.
        const pairs: [bigint, AddressLike][] = [
            [200n, addr('1')], // id=1
            [100n, addr('2')], // id=2
            [300n, addr('3')], // id=3
            [50n, addr('4')],  // id=4
            [150n, addr('5')], // id=5
        ];
        await addAll(pairs);
        await removeAndValidate(pairs[1][1]);
        /*
        await printNode(await tst.root());
        expect(await tst.root()).eq(1n);
        expect((await tst.node(1n)).left).eq(5n);
        expect((await tst.node(1n)).right).eq(3n);
        expect((await tst.node(5n)).left).eq(4n);
        */
    });

    it('Remove root node triggers rebalance', async () => {
        // Initial tree:             After removing 200:
        //        200                      300
        //       /   \                    /   \
        //     100   300                100   400
        //     / \     \                / \
        //   50  150   400            50  150
        const pairs: [bigint, AddressLike][] = [
            [200n, addr('1')], // id=1
            [100n, addr('2')], // id=2
            [300n, addr('3')], // id=3
            [50n, addr('4')],  // id=4
            [150n, addr('5')], // id=5
            [400n, addr('6')], // id=6
        ];
        await addAll(pairs);
        await removeAndValidate(pairs[0][1]);

        expect(await tst.root()).eq(3n);
        expect((await tst.node(3n)).left).eq(2n);
        expect((await tst.node(3n)).right).eq(6n);
        expect((await tst.node(2n)).left).eq(4n);
        expect((await tst.node(2n)).right).eq(5n);
    });

    it('Insert after removal maintains valid tree', async () => {
        // Initial tree:
        //        200
        //       /   \
        //     100   300
        //     / \
        //   50   150
        const pairs: [bigint, AddressLike][] = [
            [200n, addr('1')], // id=1
            [100n, addr('2')], // id=2
            [300n, addr('3')], // id=3
            [50n, addr('4')],  // id=4
            [150n, addr('5')], // id=5
        ];
        await addAll(pairs);

        // Remove node with value 100
        await removeAndValidate(pairs[1][1]);

        // Insert new nodes to test different scenarios
        await addAll([
            [125n, addr('6')], // Between 50 and 150
            [175n, addr('7')], // Between 150 and 200
            [250n, addr('8')], // Between 200 and 300
        ]);

        // Final check that the tree still satisfies all invariants
        expect(await validateTree(await tst.root())).eq(true);
        //await printNode(await tst.root());
    });

    it('Remove should fail for non-existent key', async () => {
        const nonExistentAddress = addr('f');
        await expect(tst.remove(nonExistentAddress))
            .to.be.revertedWithCustomError(sumtreeLibrary, 'KeyNotFound')
            .withArgs(nonExistentAddress);
    });

    it('Verify sum updates after removal', async () => {
        const pairs: [bigint, AddressLike][] = [
            [200n, addr('1')],
            [100n, addr('2')],
            [300n, addr('3')],
        ];
        await addAll(pairs);

        const initialSum = (await tst.node(await tst.root())).sum;
        expect(initialSum).eq(600n); // 200 + 100 + 300

        // Remove the node with value 100
        await removeAndValidate(pairs[1][1]);

        const finalSum = (await tst.node(await tst.root())).sum;
        expect(finalSum).eq(500n); // 200 + 300
    });

    it('Random insert & verify pick', async () => {
        for (let i = 0; i < 100; i++) {
            await addAll([[BigInt(randomWeight()), randomAddress()]]);

            const list = await treeToSortedList();
            const ordering = validateOrdering(list);
            if (!ordering.isValid) {
                console.log(list);
                console.log(ordering.violations);
            }
            expect(ordering.isValid).eq(true);
        }
        //await printNode(await tst.root());

        const finalList = await treeToSortedList();
        expect(validateOrdering(finalList).isValid).eq(true);
        await verifyPicks(finalList);
    });

    it('Random insert, remove & verify pick', async () => {
        // Addresses currently present in the tree; removed ones are dropped so every
        // removal attempt actually exercises the contract.
        const liveAddresses: string[] = [];

        for (let i = 0; i < 200; i++) {
            // 70% chance to add, 30% chance to remove; always add when the tree is empty.
            const shouldAdd = liveAddresses.length === 0 || Math.random() < 0.7;
            if (shouldAdd) {
                const a = randomAddress();
                const tx = await tst.add(a, randomWeight());
                await tx.wait();
                liveAddresses.push(a);
            } else {
                const indexToRemove = Math.floor(Math.random() * liveAddresses.length);
                const [addressToRemove] = liveAddresses.splice(indexToRemove, 1);
                const tx = await tst.remove(addressToRemove);
                await tx.wait();
            }
            expect(await validateTree(await tst.root())).eq(true);

            // Verify ordering after each operation
            const list = await treeToSortedList();
            const ordering = validateOrdering(list);
            if (!ordering.isValid) {
                console.log("Operation:", i);
                console.log("Current tree:", list);
                console.log("Violations:", ordering.violations);
            }
            expect(ordering.isValid).eq(true);

            // Every 10 operations, verify pick functionality for all nodes
            if (i % 10 === 0) {
                await verifyPicks(list);
            }
        }

        // Final verification of the entire tree
        const finalList = await treeToSortedList();
        expect(validateOrdering(finalList).isValid).eq(true);

        // The tree must contain exactly the addresses we believe are live.
        expect(finalList.length).eq(liveAddresses.length);
        const treeKeys = finalList.map((n) => n.key.toLowerCase()).sort();
        const expectedKeys = liveAddresses.map((a) => a.toLowerCase()).sort();
        expect(treeKeys).to.deep.eq(expectedKeys);

        await verifyPicks(finalList);
    });
});