Skip to content

Instantly share code, notes, and snippets.

@HarryR
Created January 8, 2025 10:32
Show Gist options
  • Select an option

  • Save HarryR/fd2339e6d321c49590bca371f8f7abea to your computer and use it in GitHub Desktop.

Select an option

Save HarryR/fd2339e6d321c49590bca371f8f7abea to your computer and use it in GitHub Desktop.

Revisions

  1. HarryR created this gist Jan 8, 2025.
    115 changes: 115 additions & 0 deletions Distributor.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,115 @@
    // SPDX-License-Identifier: AGPL-3.0-only

    pragma solidity ^0.8.0;

    import { IERC721 } from '@openzeppelin/contracts/token/ERC721/IERC721.sol';
    import { IERC721Receiver } from '@openzeppelin/contracts/token/ERC721/IERC721Receiver.sol';

    import { Staking } from './Staking.sol';
    import { randomBytes32 } from './Random.sol';
    import { Moderation } from './Moderation.sol';

    contract Distributor is IERC721Receiver
    {
    struct Item {
    IERC721 nft;
    uint256 tokenId;
    }

    struct Win {
    Item item;
    address who;
    }

    Item[] private items;

    Win[] private winners;

    uint256 private winBlock;

    Staking public immutable staker;

    Moderation public immutable moderation;

    constructor(
    Staking in_staking,
    Moderation in_moderation
    ) {
    staker = in_staking;

    moderation = in_moderation;
    }

    /// Moderators (allowed posters) can to trigger the distribution function
    function tick()
    external
    {
    if( ! moderation.isModerator(msg.sender) )
    {
    revert NotAllowed();
    }

    internal_distribute();
    }

    function internal_distribute()
    internal
    {
    // Transfer previous winners whenever called in a future block
    if( winners.length > 0 && winBlock != block.number )
    {
    winBlock = block.number;

    while( winners.length > 0 )
    {
    Win memory x = winners[winners.length - 1];

    winners.pop();

    x.item.nft.safeTransferFrom(address(this), x.who, x.item.tokenId);
    }
    }

    // Select a random winner & push into the stack (distributed above)
    if( staker.getSum() > 0 && items.length > moderation.poolSize() )
    {
    uint x = uint(randomBytes32()) % items.length;

    Item memory y = items[x];

    if( x != (items.length - 1) )
    {
    items[x] = items[items.length - 1];
    }

    items.pop();

    address winner = staker.random(randomBytes32());

    winners.push(Win({item: y, who: winner}));
    }
    }

    error NotAllowed();

    function onERC721Received(
    address operator,
    address from,
    uint256 tokenId,
    bytes calldata /* data */
    )
    external
    returns (bytes4)
    {
    if( ! moderation.isAllowed(msg.sender, operator, from) )
    {
    revert NotAllowed();
    }

    items.push(Item(IERC721(msg.sender), tokenId));

    internal_distribute();

    return IERC721Receiver.onERC721Received.selector;
    }
    }
    58 changes: 58 additions & 0 deletions FrontendUtils.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,58 @@
    // SPDX-License-Identifier: AGPL-3.0-only

    pragma solidity ^0.8.0;

    import { IERC20Metadata } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol";

    import { Staking, ValueField } from './Staking.sol';
    import { Moderation } from './Moderation.sol';
    import { Distributor } from './Distributor.sol';

    contract FrontendUtils {
    struct UserOverview {
    bool isOwner;
    bool isModerator;
    uint userStakedAmount;
    uint stakedTotal;
    uint stakersCount;
    uint userTokenBalance;
    string tokenSymbol;
    string tokenName;
    uint tokenSupply;
    uint tokenDecimals;
    }

    Moderation public immutable moderation;
    Distributor public immutable distributor;
    Staking public immutable staking;

    constructor(
    Moderation in_moderation,
    Distributor in_distributor,
    Staking in_staking
    ) {
    moderation = in_moderation;
    distributor = in_distributor;
    staking = in_staking;
    }

    function getOverviewForUser(address who)
    external view
    returns (UserOverview memory)
    {
    IERC20Metadata token = IERC20Metadata(address(staking.stakingToken()));

    return UserOverview({
    isOwner: moderation.owner() == who,
    isModerator: moderation.isModerator(who),
    userStakedAmount: ValueField.unwrap(staking.getValue(who)),
    stakedTotal: staking.getSum(),
    stakersCount: staking.getCount(),
    userTokenBalance: token.balanceOf(who),
    tokenSymbol: token.symbol(),
    tokenSupply: token.totalSupply(),
    tokenName: token.name(),
    tokenDecimals: token.decimals()
    });
    }
    }
    26 changes: 26 additions & 0 deletions MockNFT.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,26 @@
    // SPDX-License-Identifier: MIT

    pragma solidity ^0.8.0;

    import { ERC721 } from "@openzeppelin/contracts/token/ERC721/ERC721.sol";

    contract MockNFT is ERC721 {
    constructor()
    ERC721("Mock NFT", "MOCK")
    { }

    uint public tokenIdCounter;

    function mint(address to, uint n)
    external
    {
    for( uint i = 0; i < n; i++ )
    {
    uint tokenId = tokenIdCounter;

    tokenIdCounter += 1;

    _safeMint(to, tokenId);
    }
    }
    }
    19 changes: 19 additions & 0 deletions MockToken.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,19 @@
    // SPDX-License-Identifier: MIT

    pragma solidity ^0.8.0;

    import { ERC20 } from "@openzeppelin/contracts/token/ERC20/ERC20.sol";

    contract MockToken is ERC20 {
    constructor()
    ERC20("Mock NFT", "MOCK")
    { }

    uint public tokenIdCounter;

    function mint(address to, uint256 amount)
    external
    {
    _mint(to, amount);
    }
    }
    115 changes: 115 additions & 0 deletions Moderation.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,115 @@
    // SPDX-License-Identifier: AGPL-3.0-only

    pragma solidity ^0.8.0;

    import { Ownable } from '@openzeppelin/contracts/access/Ownable.sol';
    import { EnumerableSet } from '@openzeppelin/contracts/utils/structs/EnumerableSet.sol';

    contract Moderation is Ownable {

    using EnumerableSet for EnumerableSet.AddressSet;

    EnumerableSet.AddressSet private allowedNFTs;

    EnumerableSet.AddressSet private allowedPosters;

    uint public poolSize;

    constructor (uint in_poolSize)
    Ownable()
    {
    poolSize = in_poolSize;
    }

    /// Retrieve list of allowed NFT collections
    function getNFTList()
    external view
    returns (address[] memory out_addrs)
    {
    uint n = allowedNFTs.length();

    out_addrs = new address[](n);

    for( uint i = 0; i < n; i++ )
    {
    out_addrs[i] = allowedNFTs.at(i);
    }
    }

    /// Retrieve list of moderators
    function getModeratorList()
    external view
    returns (address[] memory out_addrs)
    {
    uint n = allowedPosters.length();

    out_addrs = new address[](n);

    for( uint i = 0; i < n; i++ )
    {
    out_addrs[i] = allowedPosters.at(i);
    }
    }

    function setPoolSize(uint newPoolSize)
    external onlyOwner
    {
    poolSize = newPoolSize;
    }

    function modifyAllowedNFTs(address[] calldata in_nfts, bool state)
    external onlyOwner
    {
    for( uint i = 0; i < in_nfts.length; i++ )
    {
    address x = in_nfts[i];

    if( state == true && ! allowedNFTs.contains(x) )
    {
    allowedNFTs.add(x);
    }
    else if( state == false && allowedNFTs.contains(x) )
    {
    allowedNFTs.remove(x);
    }
    }
    }

    function modifyModerators(address[] calldata in_addrs, bool state)
    external onlyOwner
    {
    for( uint i = 0; i < in_addrs.length; i++ )
    {
    address x = in_addrs[i];

    if( state == true && ! allowedPosters.contains(x) )
    {
    allowedPosters.add(x);
    }
    else if( state == false && allowedPosters.contains(x) )
    {
    allowedPosters.remove(x);
    }
    }
    }

    function isModerator(address who)
    external view
    returns (bool)
    {
    return allowedPosters.contains(who);
    }

    function isAllowed(address nft, address operator, address from)
    external view
    returns (bool)
    {
    if( false == allowedNFTs.contains(nft) )
    {
    return false;
    }

    return allowedPosters.length() == 0
    || (allowedPosters.contains(operator) || allowedPosters.contains(from));
    }
    }
    30 changes: 30 additions & 0 deletions Random.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,30 @@
    // SPDX-License-Identifier: AGPL-3.0-only

    pragma solidity ^0.8.0;

    error RandomBytesFailure();

    address constant RANDOM_PRECOMPILE = 0x0100000000000000000000000000000000000001;

    function randomBytes32()
    view returns (bytes32)
    {
    if( block.chainid == 1337 ) {
    return keccak256(abi.encodePacked(
    msg.sender,
    msg.value,
    block.number,
    block.timestamp,
    gasleft()
    ));
    }
    else {
    (bool success, bytes memory entropy) = RANDOM_PRECOMPILE.staticcall(
    abi.encode(32, "")
    );
    if( success != true ) {
    revert RandomBytesFailure();
    }
    return bytes32(entropy);
    }
    }
    139 changes: 139 additions & 0 deletions Staker.spec.ts
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,139 @@
    import { ethers } from 'hardhat';
    import { expect } from 'chai';
    import { Distributor, MockNFT, MockToken, Moderation, Staking } from '../src/contracts';
    import { HardhatEthersSigner } from '@nomicfoundation/hardhat-ethers/signers';

    describe('Staker', () => {
    let d: Distributor;
    let s: Staking;
    let moderation: Moderation;
    let mockNFT: MockNFT;
    let mockToken: MockToken;

    before(async () => {
    mockNFT = await (await ethers.getContractFactory('MockNFT')).deploy();

    mockToken = await (await ethers.getContractFactory('MockToken')).deploy();

    const modf = await ethers.getContractFactory('Moderation');
    const m = moderation = await modf.deploy(10);
    await m.waitForDeployment();

    const stf = await ethers.getContractFactory('SumtreeLibrary');
    const st = await stf.deploy();
    await st.waitForDeployment();

    const stakingFactory = await ethers.getContractFactory('Staking', {
    libraries: {
    'SumtreeLibrary': await st.getAddress()
    }
    });

    s = await stakingFactory.deploy(await mockToken.getAddress());

    const distributorFactory = await ethers.getContractFactory('Distributor');

    d = await distributorFactory.deploy(await s.getAddress(), await m.getAddress());

    await s.waitForDeployment();
    await d.waitForDeployment();

    //console.log('Deployed');
    });

    function calculateIterationCount(weights: Record<string,number>, sigmaThreshold: number = 5): number
    {
    // Get minimum weight probability
    const totalWeight = Object.values(weights).reduce((sum, w) => sum + w, 0);
    const minProb = Math.min(...Object.values(weights).map(w => w / totalWeight));
    //console.log(' - Total Weight', totalWeight, 'Min Prob', minProb);

    // Basic sample size calculation based on minimum probability
    // n * p * (1-p) needs to be large enough for normal approximation
    // and we want enough samples for reliable statistics
    const NORMAL_APPROXIMATION_THRESHOLD_SQUARED = 25
    return Math.ceil(NORMAL_APPROXIMATION_THRESHOLD_SQUARED * sigmaThreshold * sigmaThreshold / minProb);
    }

    function verifyDistribution(
    weightsMap: Record<string, number>,
    resultsMap: Record<string, bigint>,
    sigmaThreshold: number
    ) {
    const totalItems = Object.values(resultsMap)
    .reduce((sum, count) => sum + Number(count), 0);
    const totalWeight = Object.values(weightsMap)
    .reduce((sum, w) => sum + w, 0);

    for (const [address, weight] of Object.entries(weightsMap)) {
    const probability = weight / totalWeight;
    const expectedCount = totalItems * probability;

    // Calculate standard deviation
    const stdDev = Math.sqrt(totalItems * probability * (1 - probability));

    // Calculate acceptable range
    const margin = sigmaThreshold * stdDev;
    const range: [number, number] = [
    Math.floor(expectedCount - margin),
    Math.ceil(expectedCount + margin)
    ];

    const actualCount = Number(resultsMap[address] || 0n);
    const withinRange = actualCount >= range[0] && actualCount <= range[1];

    expect(withinRange).eq(true);
    }
    }

    it('Distribution', async () => {
    const nSigners = 5;
    const allSigners = await ethers.getSigners();
    if( allSigners.length < nSigners ) {
    throw new Error('Not enough signers!');
    }
    const signers: HardhatEthersSigner[] = [];
    for( let i = 0; i < nSigners; i++ ) {
    signers.push(allSigners[i]);
    }
    //const signers = allSigners.slice(nSigners);
    const weightsMap: Record<string,number> = {};

    // Distribute tokens to the signers, then have the signers stake them
    for( let i = 0; i < nSigners; i++ )
    {
    const w = Math.floor(1 + (Math.random() * 100));
    const x = signers[i];
    await (await mockToken.mint(x, w)).wait();
    const b = mockToken.connect(x);
    await (await b.approve(await s.getAddress(), w)).wait();
    (await s.connect(x).stake(w)).wait();
    weightsMap[x.address] = w;
    };

    // We must approve the NFTs before they can be used...
    await (await moderation.modifyAllowedNFTs([await mockNFT.getAddress()], true)).wait();

    // Mint in batches of 10, which will be randomly distributed to signers
    const batchCount = 10;
    const sigmaThreshold = 4;
    const m = calculateIterationCount(weightsMap, sigmaThreshold);
    //console.log(' - Iteration Count', m, batchCount, m/batchCount);
    for( let i = 0; i < Math.floor(m/batchCount); i++ )
    {
    const tx = await mockNFT.mint(await d.getAddress(), batchCount);
    await tx.wait();
    }

    // Collect the count of how many NFTs each account has
    const resultsMap: Record<string,bigint> = {};
    for( const s of signers )
    {
    const a = s.address;
    const b = await mockNFT.balanceOf(a);
    resultsMap[a] = b;
    }

    verifyDistribution(weightsMap, resultsMap, sigmaThreshold);
    });
    });
    88 changes: 88 additions & 0 deletions Staking.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,88 @@
    // SPDX-License-Identifier: AGPL-3.0-only

    pragma solidity ^0.8.0;

    import { IERC20Metadata } from "@openzeppelin/contracts/token/ERC20/extensions/IERC20Metadata.sol";

    import { Sumtree, SumtreeLibrary, ValueField } from './Sumtree.sol';

    contract Staking {
    IERC20Metadata public immutable stakingToken;

    Sumtree private tree;
    using SumtreeLibrary for Sumtree;

    constructor( address _stakingToken )
    {
    stakingToken = IERC20Metadata(_stakingToken);
    }

    error AlreadyStaked();
    error CannotStakeZero();
    error TransferFailed();
    error MustStakeWhole();

    function stake(ValueField amount)
    external
    {
    return stakeFor(msg.sender, amount);
    }

    function stakeFor(address who, ValueField amount)
    public
    {
    if( ValueField.unwrap(amount) == 0 ) revert CannotStakeZero();

    if( ValueField.unwrap(amount) % stakingToken.decimals() != 0 ) revert MustStakeWhole();

    if( tree.has(who) ) revert AlreadyStaked();

    bool ok = stakingToken.transferFrom(msg.sender, address(this), ValueField.unwrap(amount));

    if( false == ok ) revert TransferFailed();

    tree.add(who, amount);
    }

    function unstake()
    external
    {
    ValueField amount = tree.remove(msg.sender);

    if( ValueField.unwrap(amount) > 0 )
    {
    stakingToken.transfer(msg.sender, ValueField.unwrap(amount));
    }
    }

    function getValue(address staker)
    external view
    returns (ValueField)
    {
    if( tree.has(staker) ) {
    return tree.nodes[tree.nodesByKey[staker]].value;
    }
    return ValueField.wrap(0);
    }

    function getSum()
    external view
    returns (uint)
    {
    return tree.getSum();
    }

    function getCount()
    external view
    returns (uint)
    {
    return tree.getCount();
    }

    function random(bytes32 seed)
    external view
    returns (address)
    {
    return tree.random(seed);
    }
    }
    499 changes: 499 additions & 0 deletions Sumtree.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,499 @@
    // SPDX-License-Identifier: AGPL-3.0-only

    pragma solidity ^0.8.0;

    type SumField is uint104;
    type ValueField is uint88;
    type NodeIndex is uint16;

    struct Node {
    SumField sum;
    ValueField value;
    NodeIndex left;
    NodeIndex right;
    NodeIndex count;
    NodeIndex parent;

    address key;
    }

    using { NodeIndex_add as +, NodeIndex_neq as !=, NodeIndex_eq as ==, NodeIndex_lt as <, NodeIndex_gt as > } for NodeIndex global;
    using { ValueField_eq as ==, ValueField_lt as <, ValueField_gt as > } for ValueField global;
    using { SumField_add as +, SumField_sub as -, SumField_gte as >=, SumField_lt as < } for SumField global;

    function NodeIndex_add(NodeIndex a, NodeIndex b) pure returns (NodeIndex) {
    return NodeIndex.wrap(NodeIndex.unwrap(a) + NodeIndex.unwrap(b));
    }

    function NodeIndex_eq(NodeIndex a, NodeIndex b) pure returns (bool) {
    return NodeIndex.unwrap(a) == NodeIndex.unwrap(b);
    }

    function NodeIndex_neq(NodeIndex a, NodeIndex b) pure returns (bool) {
    return NodeIndex.unwrap(a) != NodeIndex.unwrap(b);
    }

    function NodeIndex_lt(NodeIndex a, NodeIndex b) pure returns (bool) {
    return NodeIndex.unwrap(a) < NodeIndex.unwrap(b);
    }

    function NodeIndex_gt(NodeIndex a, NodeIndex b) pure returns (bool) {
    return NodeIndex.unwrap(a) > NodeIndex.unwrap(b);
    }

    function ValueField_lt(ValueField a, ValueField b) pure returns (bool) {
    return ValueField.unwrap(a) < ValueField.unwrap(b);
    }

    function ValueField_gt(ValueField a, ValueField b) pure returns (bool) {
    return ValueField.unwrap(a) > ValueField.unwrap(b);
    }

    function ValueField_eq(ValueField a, ValueField b) pure returns (bool) {
    return ValueField.unwrap(a) == ValueField.unwrap(b);
    }

    function SumField_add(SumField a, SumField b) pure returns (SumField) {
    return SumField.wrap(SumField.unwrap(a) + SumField.unwrap(b));
    }

    function SumField_gte(SumField a, SumField b) pure returns (bool) {
    return SumField.unwrap(a) >= SumField.unwrap(b);
    }

    function SumField_sub(SumField a, SumField b) pure returns (SumField) {
    return SumField.wrap(SumField.unwrap(a) - SumField.unwrap(b));
    }

    function SumField_lt(SumField a, SumField b) pure returns (bool) {
    return SumField.unwrap(a) < SumField.unwrap(b);
    }

    function to_SumField(ValueField x) pure returns (SumField) {
    return SumField.wrap(ValueField.unwrap(x));
    }

    struct Sumtree {
    mapping(NodeIndex => Node) nodes;
    mapping(address => NodeIndex) nodesByKey; // Direct key to nodeId mapping
    NodeIndex nextNodeId;
    NodeIndex root_id;
    }

    library SumtreeLibrary {
    error DuplicateKey(address key);

    NodeIndex private constant EMPTY = NodeIndex.wrap(0);
    NodeIndex private constant NodeIndex_ONE = NodeIndex.wrap(1);
    SumField private constant SumField_ZERO = SumField.wrap(0);

    function add(Sumtree storage self, address key, ValueField value)
    public
    {
    if( has(self, key) ) {
    revert DuplicateKey(key);
    }

    Node memory new_node = Node({
    sum: to_SumField(value),
    value: value,
    left: EMPTY,
    right: EMPTY,
    parent: EMPTY,
    key: key,
    count: NodeIndex_ONE
    });

    NodeIndex new_id = self.nextNodeId = self.nextNodeId + NodeIndex_ONE;
    self.nodes[new_id] = new_node;
    self.nodesByKey[key] = new_id;

    if( self.root_id == EMPTY ) {
    self.root_id = new_id;
    return;
    }

    NodeIndex current_id = self.root_id;
    NodeIndex parent_id = EMPTY;

    while( current_id != EMPTY )
    {
    parent_id = current_id;
    Node storage currentNode = self.nodes[current_id];
    currentNode.sum = currentNode.sum + to_SumField(value);
    currentNode.count = currentNode.count + NodeIndex_ONE;

    if (value < currentNode.value ||
    (value == currentNode.value && uint160(key) < uint160(currentNode.key))) {
    current_id = currentNode.left;
    } else {
    current_id = currentNode.right;
    }
    }

    self.nodes[new_id].parent = parent_id;

    Node storage parentNode = self.nodes[parent_id];
    if (value < parentNode.value ||
    (value == parentNode.value && uint160(key) < uint160(parentNode.key))) {
    parentNode.left = new_id;
    } else {
    parentNode.right = new_id;
    }

    rebalance(self, parent_id);
    }

    // Right rotation
    function rotateRight(Sumtree storage tree, NodeIndex y)
    private
    returns (NodeIndex)
    {
    Node storage yNode = tree.nodes[y];

    NodeIndex x = tree.nodes[y].left;

    Node storage xNode = tree.nodes[x];
    NodeIndex T2 = xNode.right;

    // Update parent references
    NodeIndex yParent = tree.nodes[y].parent;

    xNode.parent = yParent;
    yNode.parent = x;
    if (T2 != EMPTY) {
    tree.nodes[T2].parent = y;
    }

    // Perform rotation
    xNode.right = y;
    yNode.left = T2;

    // Update counts and sums
    updateCountAndSum(tree, y);
    updateCountAndSum(tree, x);

    // Update parent's child reference
    if (yParent != EMPTY) {
    Node storage yParentNode = tree.nodes[yParent];
    if (yParentNode.left == y) {
    yParentNode.left = x;
    } else {
    yParentNode.right = x;
    }
    } else {
    tree.root_id = x;
    }

    return x;
    }

    // Left rotation
    function rotateLeft(Sumtree storage tree, NodeIndex x)
    private
    returns (NodeIndex)
    {
    Node storage xNode = tree.nodes[x];

    NodeIndex y = xNode.right;

    Node storage yNode = tree.nodes[y];
    NodeIndex T2 = yNode.left;

    // Update parent references
    NodeIndex xParent = xNode.parent;
    yNode.parent = xParent;
    xNode.parent = y;
    if (T2 != EMPTY) {
    tree.nodes[T2].parent = x;
    }

    // Perform rotation
    yNode.left = x;
    xNode.right = T2;

    // Update counts and sums
    updateCountAndSum(tree, x);
    updateCountAndSum(tree, y);

    // Update parent's child reference
    if (xParent != EMPTY) {
    Node storage xParentNode = tree.nodes[xParent];
    if (xParentNode.left == x) {
    xParentNode.left = y;
    } else {
    xParentNode.right = y;
    }
    } else {
    tree.root_id = y;
    }

    return y;
    }

    // Rebalance tree after insertion or deletion
    function rebalance(Sumtree storage tree, NodeIndex nodeId)
    private
    {
    NodeIndex current = nodeId;

    while (current != EMPTY) {
    updateCountAndSum(tree, current);

    int256 balance = getBalance(tree, current);

    // We consider a subtree imbalanced if the difference in count is more than 2x
    Node storage currentNode = tree.nodes[current];
    bool isImbalanced = balance > int256(uint256(NodeIndex.unwrap(getCount(tree, currentNode.right))) * 2) ||
    -balance > int256(uint256(NodeIndex.unwrap(getCount(tree, currentNode.left))) * 2);

    if (isImbalanced) {
    if (balance > 0) {
    // Left heavy
    if (getBalance(tree, currentNode.left) < 0) {
    // Left-Right case
    currentNode.left = rotateLeft(tree, currentNode.left);
    }
    current = rotateRight(tree, current);
    } else {
    // Right heavy
    if (getBalance(tree, currentNode.right) > 0) {
    // Right-Left case
    currentNode.right = rotateRight(tree, currentNode.right);
    }
    current = rotateLeft(tree, current);
    }
    }

    current = tree.nodes[current].parent;
    }
    }

    function getCount(Sumtree storage tree, NodeIndex nodeId) private view returns (NodeIndex) {
    return nodeId == EMPTY ? EMPTY : tree.nodes[nodeId].count;
    }

    function getBalance(Sumtree storage tree, NodeIndex nodeId) private view returns (int256) {
    if (nodeId == EMPTY) return 0;
    Node storage node = tree.nodes[nodeId];
    return int256(uint256(NodeIndex.unwrap(getCount(tree, node.left)))) -
    int256(uint256(NodeIndex.unwrap(getCount(tree, node.right))));
    }

    function updateCountAndSum(Sumtree storage tree, NodeIndex nodeId) private {
    if (nodeId == EMPTY) return;
    Node storage node = tree.nodes[nodeId];
    node.count = NodeIndex_ONE + getCount(tree, node.left) + getCount(tree, node.right);
    SumField leftSum = node.left != EMPTY ? tree.nodes[node.left].sum : SumField_ZERO;
    SumField rightSum = node.right != EMPTY ? tree.nodes[node.right].sum : SumField_ZERO;
    node.sum = leftSum + to_SumField(node.value) + rightSum;
    }

    error KeyNotFound(address key);

    function remove(Sumtree storage self, address key)
    public
    returns (ValueField value)
    {
    if (!has(self, key)) {
    revert KeyNotFound(key);
    }

    NodeIndex nodeId = self.nodesByKey[key];
    Node storage node = self.nodes[nodeId];
    value = node.value;

    // Case 1: Node has no children or one child
    NodeIndex replacementId;
    if (node.left == EMPTY) {
    replacementId = node.right;
    } else if (node.right == EMPTY) {
    replacementId = node.left;
    }
    // Case 2: Node has both children
    else {
    // Find successor (smallest value in right subtree)
    NodeIndex successorId = node.right;
    while (self.nodes[successorId].left != EMPTY) {
    successorId = self.nodes[successorId].left;
    }

    Node storage successorNode = self.nodes[successorId];

    // Store the successor's original position details
    NodeIndex successorParent = successorNode.parent;
    NodeIndex successorRight = successorNode.right;

    // If successor is not the immediate right child
    if (successorParent != nodeId) {
    // Replace successor with its right child
    if (successorRight != EMPTY) {
    self.nodes[successorRight].parent = successorParent;
    }
    self.nodes[successorParent].left = successorRight;

    // Set successor's right to node's right child
    successorNode.right = node.right;
    self.nodes[node.right].parent = successorId;
    }

    // Move node's left child to successor
    successorNode.left = node.left;
    self.nodes[node.left].parent = successorId;

    // Put successor in node's position
    successorNode.parent = node.parent;
    if (node.parent == EMPTY) {
    self.root_id = successorId;
    } else {
    Node storage parentNode = self.nodes[node.parent];
    if (parentNode.left == nodeId) {
    parentNode.left = successorId;
    } else {
    parentNode.right = successorId;
    }
    }

    // Start rebalancing from successor's original parent
    rebalance(self, successorParent != nodeId ? successorParent : successorId);

    delete self.nodes[nodeId];
    //delete self.nodesByKey[key];
    self.nodesByKey[key] = NodeIndex.wrap(0);
    return value;
    }

    // Handle Case 1 (no children or one child)
    NodeIndex parentId = node.parent;
    if (replacementId != EMPTY) {
    self.nodes[replacementId].parent = parentId;
    }

    if (parentId != EMPTY) {
    Node storage parentNode = self.nodes[parentId];
    if (parentNode.left == nodeId) {
    parentNode.left = replacementId;
    } else {
    parentNode.right = replacementId;
    }
    rebalance(self, parentId);
    } else {
    self.root_id = replacementId;
    }

    delete self.nodes[nodeId];
    //delete self.nodesByKey[key];
    self.nodesByKey[key] = NodeIndex.wrap(0);
    return value;
    }

    function replaceNode(Sumtree storage self, NodeIndex oldId, NodeIndex newId)
    private
    {
    NodeIndex parentId = self.nodes[oldId].parent;

    // Update parent's child pointer
    if (parentId != EMPTY) {
    Node storage parentNode = self.nodes[parentId];
    if (parentNode.left == oldId) {
    parentNode.left = newId;
    } else {
    parentNode.right = newId;
    }
    } else {
    self.root_id = newId;
    }

    // Update new node's parent pointer
    if (newId != EMPTY) {
    self.nodes[newId].parent = parentId;
    }
    }

    error EmptyTree();
    error RandomFailed();
    error OutOfRange();

    function random(Sumtree storage self, bytes32 seed)
    internal view
    returns (address)
    {
    uint seedHashed = uint(keccak256(abi.encodePacked(seed)));

    uint remaining_r = (seedHashed % getSum(self));

    return pick(self, remaining_r);
    }

    function pick(Sumtree storage self, uint random_value)
    internal view
    returns (address)
    {
    if( self.root_id == EMPTY ) {
    revert EmptyTree();
    }

    if( random_value >= getSum(self) ) {
    revert OutOfRange();
    }

    NodeIndex current_id = self.root_id;

    while( current_id != EMPTY )
    {
    Node storage current = self.nodes[current_id];

    SumField left_sum = SumField_ZERO;

    // Get sum of left subtree
    if( current.left != EMPTY ) {
    left_sum = self.nodes[current.left].sum;
    }

    // If random_value falls in left subtree
    if( random_value < SumField.unwrap(left_sum) )
    {
    current_id = current.left;
    continue;
    }

    // If random_value falls in current node's range
    if( random_value < SumField.unwrap(left_sum + to_SumField(current.value)) ) {
    return current.key;
    }

    random_value = random_value - SumField.unwrap((left_sum + to_SumField(current.value)));

    current_id = current.right;
    }

    revert RandomFailed();
    }

    function getCount(Sumtree storage self)
    internal view
    returns (uint)
    {
    if( self.root_id == EMPTY ) {
    return 0;
    }
    return NodeIndex.unwrap(self.nodes[self.root_id].count);
    }

    function getSum(Sumtree storage self)
    internal view
    returns (uint)
    {
    if( self.root_id == EMPTY ) {
    return 0;
    }
    return SumField.unwrap(self.nodes[self.root_id].sum);
    }

    function has(Sumtree storage self, address key)
    internal view
    returns (bool)
    {
    return self.nodesByKey[key] != EMPTY;
    }
    }
    711 changes: 711 additions & 0 deletions Sumtree.spec.ts
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,711 @@
    import { ethers } from 'hardhat';
    import { SumtreeLibrary, TestSumtree } from '../src/contracts';
    import { randomBytes, hexlify, getAddress, AddressLike } from 'ethers';
    import { expect } from 'chai';

    describe('Sumtree', () => {
    let tst: TestSumtree;
    let sumtreeLibrary: SumtreeLibrary;

    before(async () => {
    const stf = await ethers.getContractFactory('SumtreeLibrary');
    sumtreeLibrary = await stf.deploy();
    await sumtreeLibrary.waitForDeployment();
    })

    beforeEach(async () => {
    const f = await ethers.getContractFactory('TestSumtree', {
    libraries: {
    'SumtreeLibrary': await sumtreeLibrary.getAddress()
    }
    });
    tst = await f.deploy()
    await tst.waitForDeployment();
    });

    interface ValidationResult {
    isValid: boolean;
    actualSum?: bigint;
    actualCount?: bigint;
    }

    async function printNode(nodeId: bigint, indent:number = 0, prefix:string='') {
    const n = await tst.node(nodeId);
    console.log(' '.repeat(indent), prefix, `id=${nodeId} s=${n.sum} v=${n.value} c=${n.count} k=${n.key} p=${n.parent}`);
    if( n.left != 0n ) {
    await printNode(n.left, indent + 2, 'L');
    }
    if( n.right != 0n ) {
    await printNode(n.right, indent + 2, 'R');
    }
    }

    async function validateNode(nodeId: bigint): Promise<ValidationResult> {
    if (nodeId === 0n) {
    return { isValid: true, actualSum: BigInt(0), actualCount: 0n };
    }

    const node = await tst.node(nodeId);

    // Validate left subtree
    if (node.left !== 0n) {
    const leftNode = await tst.node(node.left);

    // Check ordering
    if (leftNode.value > node.value ||
    (leftNode.value === node.value && BigInt(leftNode.key) >= BigInt(node.key))) {
    console.error(`Order violation at node ${nodeId} with left child ${node.left}`);
    console.error(`Parent: (${node.value.toString()}, ${node.key})`);
    console.error(`Left: (${leftNode.value.toString()}, ${leftNode.key})`);
    return { isValid: false };
    }

    // Check parent pointer
    if (leftNode.parent !== nodeId) {
    console.error(`Parent pointer mismatch: node ${node.left} should point to ${nodeId}`);
    return { isValid: false };
    }
    }

    // Validate right subtree
    if (node.right !== 0n) {
    const rightNode = await tst.node(node.right);

    // Check ordering
    if (rightNode.value < node.value ||
    (rightNode.value === node.value && BigInt(rightNode.key) <= BigInt(node.key))) {
    console.error(`Order violation at node ${nodeId} with right child ${node.right}`);
    console.error(`Parent: (${node.value.toString()}, ${node.key})`);
    console.error(`Right: (${rightNode.value.toString()}, ${rightNode.key})`);
    return { isValid: false };
    }

    // Check parent pointer
    if (rightNode.parent !== nodeId) {
    console.error(`Parent pointer mismatch: node ${node.right} should point to ${nodeId}`);
    return { isValid: false };
    }
    }

    // Recursively validate children and get their sums and counts
    const leftResult = await validateNode(node.left);
    if (!leftResult.isValid) return { isValid: false };

    const rightResult = await validateNode(node.right);
    if (!rightResult.isValid) return { isValid: false };

    // Calculate actual sum and count
    const actualSum = leftResult.actualSum! + node.value + rightResult.actualSum!;
    const actualCount = leftResult.actualCount! + 1n + rightResult.actualCount!;

    // Validate sum
    if (node.sum !== actualSum) {
    console.error(`Sum mismatch at node ${nodeId}:`);
    console.error(`Expected: ${node.sum.toString()}`);
    console.error(`Actual: ${actualSum.toString()}`);
    return { isValid: false };
    }

    // Validate count
    if (node.count !== actualCount) {
    console.error(`Count mismatch at node ${nodeId}:`);
    console.error(`Expected: ${node.count}`);
    console.error(`Actual: ${actualCount}`);
    return { isValid: false };
    }

    return {
    isValid: true,
    actualSum,
    actualCount
    };
    }

    async function validateTree(root: bigint): Promise<boolean> {
    try {
    const result = await validateNode(root);
    if (result.isValid) {
    //console.log("Tree is valid:");
    //console.log(`Total sum: ${result.actualSum!.toString()}`);
    //console.log(`Total count: ${result.actualCount}`);
    }
    return result.isValid;
    } catch (error) {
    //console.error("Error validating tree:", error);
    return false;
    }
    }

    interface NodeValue {
    value: bigint;
    key: string;
    }

    async function treeToSortedList(): Promise<NodeValue[]> {
    const values: NodeValue[] = [];
    const rootId = await tst.root();

    async function inorderTraversal(nodeId: bigint): Promise<void> {
    if (nodeId === 0n) return;

    const node = await tst.node(nodeId);

    // Traverse left
    if (node.left !== 0n) {
    await inorderTraversal(node.left);
    }

    // Add current node
    values.push({
    value: node.value,
    key: node.key
    });

    // Traverse right
    if (node.right !== 0n) {
    await inorderTraversal(node.right);
    }
    }

    await inorderTraversal(rootId);
    return values;
    }

    it('Simple left rotation', async () => {
    // Tree becomes:
    // 100 200
    // \ -> / \
    // 200 100 300
    // \
    // 300

    // Insert in increasing value order
    const pairs: [bigint,AddressLike][] = [
    [100n, getAddress('0x' + '1'.repeat(40))], // id=1
    [200n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    ];

    for( const [w, a] of pairs ) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }
    //await printNode(await tst.root());

    expect(await tst.root()).eq(2n);
    expect((await tst.node(2n)).left).eq(1n);
    expect((await tst.node(2n)).right).eq(3n);
    });

    it('Simple right rotation', async () => {
    // Tree becomes:
    // 300 200
    // / -> / \
    // 200 100 300
    // /
    //100

    // Insert in decreasing value order
    const pairs: [bigint,AddressLike][] = [
    [300n, getAddress('0x' + '4'.repeat(40))], // id=1
    [200n, getAddress('0x' + '5'.repeat(40))], // id=2
    [100n, getAddress('0x' + '6'.repeat(40))], // id=3
    ];

    for( const [w, a] of pairs ) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }
    //await printNode(await tst.root());

    expect(await tst.root()).eq(2n);
    expect((await tst.node(2n)).left).eq(3n);
    expect((await tst.node(2n)).right).eq(1n);
    });

    it('Right-right case (double rotation)', async () => {
    // Tree becomes:
    // 100 100 200
    // \ -> \ -> / \
    // 300 200 100 300
    // / \
    // 200 300

    const pairs: [bigint,AddressLike][] = [
    [100n, getAddress('0x' + '7'.repeat(40))], // id=1
    [300n, getAddress('0x' + '8'.repeat(40))], // id=2
    [200n, getAddress('0x' + '9'.repeat(40))], // id=3
    ];

    for( const [w, a] of pairs ) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }
    //await printNode(await tst.root());

    expect(await tst.root()).eq(3n);
    expect((await tst.node(3n)).left).eq(1n);
    expect((await tst.node(3n)).right).eq(2n);
    });

    it('Left-right case (double rotation)', async () => {
    // Tree becomes:
    // 300 300 200
    // / -> / -> / \
    // 100 200 100 300
    // \ /
    // 200 100

    const pairs: [bigint,AddressLike][] = [
    [300n, getAddress('0x' + '7'.repeat(40))], // id=1
    [100n, getAddress('0x' + '8'.repeat(40))], // id=2
    [200n, getAddress('0x' + '9'.repeat(40))], // id=3
    ];

    for( const [w, a] of pairs ) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }
    //await printNode(await tst.root());

    expect(await tst.root()).eq(3n);
    expect((await tst.node(3n)).left).eq(2n);
    expect((await tst.node(3n)).right).eq(1n);
    });

    it('Key-based ordering (same values)', async () => {
    const pairs: [bigint,AddressLike][] = [
    [100n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [100n, getAddress('0x' + '3'.repeat(40))], // id=3
    ];

    for( const [w, a] of pairs ) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }
    //await printNode(await tst.root());

    expect(await tst.root()).eq(2n);
    expect((await tst.node(2n)).left).eq(1n);
    expect((await tst.node(2n)).right).eq(3n);
    });

    it('Remove leaf node from balanced tree', async () => {
    // Tree structure:
    // 200
    // / \
    // 100 300
    //
    // After removing 100:
    // 300
    // /
    // 200

    const pairs: [bigint,AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    ];

    // Build initial tree
    for (const [w, a] of pairs) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }

    //await printNode(await tst.root());

    // Remove leaf node 100
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);

    //await printNode(await tst.root());

    // Verify structure
    const rootNodeId = await tst.root()
    expect(rootNodeId).eq(3n);
    expect((await tst.node(rootNodeId)).left).eq(1n);
    expect((await tst.node(rootNodeId)).right).eq(0n);
    });

    it('Remove node with one child', async () => {
    // Initial tree:
    // 200
    // / \
    // 100 300
    // /
    // 50
    //
    // After removing 100:
    // 200
    // / \
    // 50 300

    const pairs: [bigint,AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    ];

    // Build initial tree
    for (const [w, a] of pairs) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }

    // Remove node 100
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);

    // Verify structure
    expect(await tst.root()).eq(1n);
    expect((await tst.node(1n)).left).eq(4n);
    expect((await tst.node(1n)).right).eq(3n);
    });

    it('Remove node with two children', async () => {
    // Initial tree:
    // 200
    // / \
    // 100 300
    // / \
    // 50 150
    //
    // After removing 100:
    // 200
    // / \
    // 50 300
    // /
    // 150

    const pairs: [bigint,AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    [150n, getAddress('0x' + '5'.repeat(40))], // id=5
    ];

    // Build initial tree
    for (const [w, a] of pairs) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }

    //await printNode(await tst.root());

    // Remove node 100
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);

    /*
    await printNode(await tst.root());
    // Verify structure
    expect(await tst.root()).eq(1n);
    expect((await tst.node(1n)).left).eq(5n);
    expect((await tst.node(1n)).right).eq(3n);
    expect((await tst.node(5n)).left).eq(4n);
    */
    });

    it('Remove root node triggers rebalance', async () => {
    // Initial tree:
    // 200
    // / \
    // 100 300
    // / \ \
    // 50 150 400
    //
    // After removing 200:
    // 300
    // / \
    // 100 400
    // / \
    // 50 150

    const pairs: [bigint,AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    [150n, getAddress('0x' + '5'.repeat(40))], // id=5
    [400n, getAddress('0x' + '6'.repeat(40))], // id=6
    ];

    // Build initial tree
    for (const [w, a] of pairs) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }

    // Remove root node
    const tx = await tst.remove(pairs[0][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);

    // Verify structure
    expect(await tst.root()).eq(3n);
    expect((await tst.node(3n)).left).eq(2n);
    expect((await tst.node(3n)).right).eq(6n);
    expect((await tst.node(2n)).left).eq(4n);
    expect((await tst.node(2n)).right).eq(5n);
    });

    it('Insert after removal maintains valid tree', async () => {
    // Initial tree:
    // 200
    // / \
    // 100 300
    // / \
    // 50 150

    const pairs: [bigint,AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))], // id=1
    [100n, getAddress('0x' + '2'.repeat(40))], // id=2
    [300n, getAddress('0x' + '3'.repeat(40))], // id=3
    [50n, getAddress('0x' + '4'.repeat(40))], // id=4
    [150n, getAddress('0x' + '5'.repeat(40))], // id=5
    ];

    // Build initial tree
    for (const [w, a] of pairs) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }

    // Remove node with value 100
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);

    // Insert new nodes to test different scenarios
    const newPairs: [bigint,AddressLike][] = [
    [125n, getAddress('0x' + '6'.repeat(40))], // Between 50 and 150
    [175n, getAddress('0x' + '7'.repeat(40))], // Between 150 and 200
    [250n, getAddress('0x' + '8'.repeat(40))], // Between 200 and 300
    ];

    for (const [w, a] of newPairs) {
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);
    }

    // Verify final tree structure maintains all invariants
    expect(await validateTree(await tst.root())).eq(true);

    // Optional: Print final tree state for debugging
    //await printNode(await tst.root());
    });

    it('Remove should fail for non-existent key', async () => {
    const nonExistentAddress = getAddress('0x' + 'f'.repeat(40));
    await expect(tst.remove(nonExistentAddress))
    .to.be.revertedWithCustomError(sumtreeLibrary, 'KeyNotFound')
    .withArgs(nonExistentAddress);
    });

    it('Verify sum updates after removal', async () => {
    // Add three nodes
    const pairs: [bigint,AddressLike][] = [
    [200n, getAddress('0x' + '1'.repeat(40))],
    [100n, getAddress('0x' + '2'.repeat(40))],
    [300n, getAddress('0x' + '3'.repeat(40))],
    ];

    for (const [w, a] of pairs) {
    const tx = await tst.add(a, w);
    await tx.wait();
    }

    // Get initial sum
    const initialSum = (await tst.node(await tst.root())).sum;
    expect(initialSum).eq(600n); // 200 + 100 + 300

    // Remove middle node (100)
    const tx = await tst.remove(pairs[1][1]);
    await tx.wait();

    // Verify sum is updated
    const finalSum = (await tst.node(await tst.root())).sum;
    expect(finalSum).eq(500n); // 200 + 300
    });

    function validateOrdering(list: NodeValue[]): { isValid: boolean; violations: string[] } {
    const violations: string[] = [];

    for (let i = 0; i < list.length - 1; i++) {
    const current = list[i];
    const next = list[i + 1];

    // Check if next is greater than current
    if (next.value < current.value ||
    (next.value === current.value && next.key.toLowerCase() <= current.key.toLowerCase())) {

    violations.push(
    `Ordering violation at index ${i}:\n` +
    `Current: (value: ${current.value.toString()}, key: ${current.key})\n` +
    `Next: (value: ${next.value.toString()}, key: ${next.key})`
    );
    }
    }

    return {
    isValid: violations.length === 0,
    violations
    };
    }

    it('Random insert & verify pick', async () => {
    for( let i = 0; i < 100; i++ ) {
    const a = getAddress(hexlify(randomBytes(20)));
    const w = 1 + Math.floor(Math.random() * 1000);
    const tx = await tst.add(a, w);
    await tx.wait();
    expect(await validateTree(await tst.root())).eq(true);

    const x = await treeToSortedList();
    const y = validateOrdering(x);
    if(! y.isValid ) {
    console.log(x);
    console.log(y.violations);
    }
    expect(y.isValid).eq(true);
    }
    //await printNode(await tst.root());

    const x = await treeToSortedList();
    const y = validateOrdering(x);
    expect(y.isValid).eq(true);

    let total: bigint = 0n;
    for( const z of x )
    {
    const totalBefore = total;
    total += z.value;

    // If we pick the beginning & end of range we get the expected
    const a = await tst.pick(totalBefore);
    const b = await tst.pick(total - 1n);
    expect(a).eq(b);
    expect(b).eq(z.key);

    // Somewhere in the middle
    const c = await tst.pick(totalBefore + ((total - 1n - totalBefore) / 2n));
    expect(c).eq(z.key);
    }
    });

    it('Random insert, remove & verify pick', async () => {
    // Keep track of all addresses and their current presence in tree
    const addressMap = new Map<string, boolean>();
    const existingAddresses: string[] = [];
    let numNodes = 0;

    for (let i = 0; i < 200; i++) {
    // 70% chance to add, 30% chance to remove when we have nodes
    const shouldAdd = numNodes === 0 || Math.random() < 0.7;

    if (shouldAdd) {
    // Add new node
    const a = getAddress(hexlify(randomBytes(20)));
    const w = 1 + Math.floor(Math.random() * 1000);
    const tx = await tst.add(a, w);
    const receipt = await tx.wait();
    //console.log('Insert cost', receipt?.cumulativeGasUsed, numNodes);

    addressMap.set(a.toLowerCase(), true);
    existingAddresses.push(a);
    numNodes++;

    // Verify tree is valid after addition
    expect(await validateTree(await tst.root())).eq(true);
    } else {
    // Remove random existing node
    const indexToRemove = Math.floor(Math.random() * existingAddresses.length);
    const addressToRemove = existingAddresses[indexToRemove];

    if (addressMap.get(addressToRemove.toLowerCase())) {
    const tx = await tst.remove(addressToRemove);
    const receipt = await tx.wait();
    //console.log('Remove cost', receipt?.cumulativeGasUsed, numNodes);

    addressMap.set(addressToRemove.toLowerCase(), false);
    numNodes--;

    // Verify tree is valid after removal
    expect(await validateTree(await tst.root())).eq(true);
    }
    }

    // Verify ordering after each operation
    const x = await treeToSortedList();
    const y = validateOrdering(x);
    if (!y.isValid) {
    console.log("Operation:", i);
    console.log("Current tree:", x);
    console.log("Violations:", y.violations);
    }
    expect(y.isValid).eq(true);

    // Every 10 operations, verify pick functionality for all nodes
    if (i % 10 === 0) {
    const x = await treeToSortedList();
    let total: bigint = 0n;

    for (const z of x) {
    const totalBefore = total;
    total += z.value;

    // Verify picks at range boundaries
    const a = await tst.pick(totalBefore);
    const b = await tst.pick(total - 1n);
    expect(a).eq(b);
    expect(b).eq(z.key);

    // Verify pick in middle of range
    const c = await tst.pick(totalBefore + ((total - 1n - totalBefore) / 2n));
    expect(c).eq(z.key);
    }
    }
    }

    // Final verification of entire tree
    const finalList = await treeToSortedList();
    expect(validateOrdering(finalList).isValid).eq(true);

    // Verify final tree size matches our tracking
    expect(finalList.length).eq(numNodes);

    // Final verification of pick functionality for all nodes
    let total: bigint = 0n;
    for (const z of finalList) {
    const totalBefore = total;
    total += z.value;

    const a = await tst.pick(totalBefore);
    const b = await tst.pick(total - 1n);
    expect(a).eq(b);
    expect(b).eq(z.key);

    const c = await tst.pick(totalBefore + ((total - 1n - totalBefore) / 2n));
    expect(c).eq(z.key);
    }
    });
    });
    42 changes: 42 additions & 0 deletions TestSumtree.sol
    Original file line number Diff line number Diff line change
    @@ -0,0 +1,42 @@
    // SPDX-License-Identifier: MIT
    pragma solidity ^0.8.0;

    import {Node, Sumtree, SumtreeLibrary, ValueField, NodeIndex} from '../Sumtree.sol';

    contract TestSumtree {
    Sumtree private st;
    using SumtreeLibrary for Sumtree;

    function add(address k, ValueField v) external {
    st.add(k,v);
    }

    function remove(address key) public returns (ValueField value) {
    return st.remove(key);
    }

    function random(bytes32 seed) public view returns (address) {
    return st.random(seed);
    }

    function pick(uint r) public view returns (address) {
    return st.pick(r);
    }

    function count() public view returns (uint) {
    return st.getCount();
    }

    function total() public view returns (uint) {
    return st.getSum();
    }

    function root() public view returns (NodeIndex) {
    return st.root_id;
    }

    function node(NodeIndex idx) public view returns (Node memory)
    {
    return st.nodes[idx];
    }
    }