// A Unity port of SwissTable using Burst intrinsics
using System;
using System.Collections;
using System.Collections.Generic;
using System.Diagnostics;
using System.Diagnostics.CodeAnalysis;
using System.Runtime.CompilerServices;
using System.Runtime.InteropServices;
using System.Runtime.Serialization;
using System.Threading;
using Unity.Burst;
using Unity.Burst.CompilerServices;
using Unity.Burst.Intrinsics;
using Unity.Collections;
using Unity.Jobs;
using Unity.Mathematics;
using UnityEngine;
using Debug = UnityEngine.Debug;
using static DefaultNamespace.SwissTableHelper;
using static Unity.Burst.Intrinsics.X86;
#nullable enable
namespace DefaultNamespace
{
public class SwissTable<TKey, TValue> : IDictionary<TKey, TValue>, IDictionary, IReadOnlyDictionary<TKey, TValue>, ISerializable, IDeserializationCallback where TKey : notnull
{
// Terminology:
// Capacity: the maximum number of non-empty items the hash table can hold before growing (comes from `EnsureCapacity`)
// count: the actual number of stored items
// growth_left: because of tombstones, "non-empty" does not necessarily mean a slot holds a value
// bucket: `Entry` in the code, which can hold a key-value pair
// capacity = count + growth_left
// entries.Length = capacity + tombstones + left_by_load_factor
// This struct contains all meaningful data except _version and _comparer.
// Why the comparer is not in the inner table:
// When resizing (which is more common), we need to reallocate everything except the comparer.
// Only when constructing from another collection can the user supply a new comparer, so we treat that as an edge case.
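// A small worked example of these invariants (illustrative numbers only, not from the original source):
//   8 buckets, 5 live entries, 1 tombstone
//   count       = 5
//   growth_left = bucket_mask_to_capacity(7) - count - tombstones = 7 - 5 - 1 = 1
//   capacity    = count + growth_left = 6
//   entries.Length = 8 = capacity + tombstones + left_by_load_factor (6 + 1 + 1)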
internal struct RawTableInner
{
internal int _bucket_mask;
// TODO: If we could make _controls memory aligned(explicit memory layout or _dummy variable?), we could use load_align rather than load to speed up
// TODO: maybe _controls could be Span and allocate memory from unmanaged memory?
internal byte[] _controls;
internal Entry[]? _entries;
// Number of elements that can be inserted before we need to grow the table.
// This has to be tracked separately because of tombstones (DELETED slots).
internal int _growth_left;
// number of real values stored in the map
// `items` in rust
internal int _count;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void set_ctrl_h2(int index, int hash)
{
set_ctrl(index, h2(hash));
}
/// Sets a control byte, and possibly also the replicated control byte at
/// the end of the array.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void set_ctrl(int index, byte ctrl)
{
// Replicate the first Group::WIDTH control bytes at the end of
// the array without using a branch:
// - If index >= Group::WIDTH then index == index2.
// - Otherwise index2 == self.bucket_mask + 1 + index.
//
// The very last replicated control byte is never actually read because
// we mask the initial index for unaligned loads, but we write it
// anyways because it makes the set_ctrl implementation simpler.
//
// If there are fewer buckets than Group::WIDTH then this code will
// replicate the buckets at the end of the trailing group. For example
// with 2 buckets and a group size of 4, the control bytes will look
// like this:
//
// Real | Replicated
// ---------------------------------------------
// | [A] | [B] | [EMPTY] | [EMPTY] | [A] | [B] |
// ---------------------------------------------
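// Worked example for the 2-bucket diagram above (bucket_mask = 1, GROUP_WIDTH = 4,
// controls.Length = 2 + 4 = 6):
//   index = 0: index2 = ((0 - 4) & 1) + 4 = 0 + 4 = 4  -> [A] written at 0 and 4
//   index = 1: index2 = ((1 - 4) & 1) + 4 = 1 + 4 = 5  -> [B] written at 1 and 5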
var index2 = ((index - GROUP_WIDTH) & _bucket_mask) + GROUP_WIDTH;
_controls[index] = ctrl;
_controls[index2] = ctrl;
}
// Always finds a slot for a new entry.
// Does not check for an existing key; the caller must ensure the key is not already present.
internal int find_insert_slot(int hash)
{
return DispatchFindInsertSlot(hash, _controls, _bucket_mask);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void record_item_insert_at(int index, byte old_ctrl, int hash)
{
_growth_left -= special_is_empty_with_int_return(old_ctrl);
set_ctrl_h2(index, hash);
_count += 1;
}
}
internal IEqualityComparer<TKey>? _comparer;
// Each add/expand/shrink increments the version; note that remove does not.
// If the version changes while enumerating, the enumerator throws, because the collection must not change during enumeration.
internal int _version;
// The enumerator does not throw when this changes; instead, it refreshes its data.
internal int _tolerantVersion;
internal RawTableInner rawTable;
public int Count => rawTable._count;
// The real number of buckets the SwissTable has allocated.
// Always a power of 2.
// This means the upper bound is not Int32.MaxValue (0x7FFF_FFFF) but 0x4000_0000.
// TODO: Should we throw an exception if the user tries to grow past the largest size?
// Growing any larger is hard (impossible?), since array length is limited to 0x7FFF_FFFF.
internal int _buckets => rawTable._bucket_mask + 1;
public SwissTable() : this(0, null) { }
public SwissTable(int capacity) : this(capacity, null) { }
public SwissTable(IEqualityComparer<TKey>? comparer) : this(0, comparer) { }
public SwissTable(int capacity, IEqualityComparer<TKey>? comparer)
{
InitializeInnerTable(capacity);
InitializeComparer(comparer);
}
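// A minimal usage sketch of the public Dictionary-like surface (illustrative only, not part of the original source):
//
//     var table = new SwissTable<string, int>();
//     table.Add("a", 1);
//     table["b"] = 2;                                  // insert or overwrite
//     if (table.TryGetValue("a", out var v)) { /* v == 1 */ }
//     foreach (var pair in table) { /* KeyValuePair<string, int> */ }
//     table.Remove("b");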
private void InitializeComparer(IEqualityComparer<TKey>? comparer)
{
if (comparer is not null && comparer != EqualityComparer<TKey>.Default) // first check for null to avoid forcing default comparer instantiation unnecessarily
{
_comparer = comparer;
}
}
public SwissTable(IDictionary<TKey, TValue> dictionary) : this(dictionary, null) { }
public SwissTable(IDictionary<TKey, TValue> dictionary, IEqualityComparer<TKey>? comparer)
{
InitializeComparer(comparer);
if (dictionary == null)
{
InitializeInnerTable(0);
throw new ArgumentNullException(nameof(dictionary));
}
CloneFromCollection(dictionary);
}
public SwissTable(IEnumerable<KeyValuePair<TKey, TValue>> collection) : this(collection, null) { }
public SwissTable(IEnumerable<KeyValuePair<TKey, TValue>> collection, IEqualityComparer<TKey>? comparer)
{
InitializeComparer(comparer);
if (collection == null)
{
InitializeInnerTable(0);
throw new ArgumentNullException(nameof(collection));
}
CloneFromCollection(collection);
}
internal static class HashHelpers
{
private static ConditionalWeakTable<object, SerializationInfo>? s_serializationInfoTable;
public static ConditionalWeakTable<object, SerializationInfo> SerializationInfoTable
{
get
{
if (s_serializationInfoTable == null)
Interlocked.CompareExchange(ref s_serializationInfoTable, new ConditionalWeakTable<object, SerializationInfo>(), null);
return s_serializationInfoTable;
}
}
}
protected SwissTable(SerializationInfo info, StreamingContext context)
{
// We can't do anything with the keys and values until the entire graph has been deserialized
// and we have a reasonable estimate that GetHashCode is not going to fail. For the time being,
// we'll just cache this. The graph is not valid until OnDeserialization has been called.
HashHelpers.SerializationInfoTable.Add(this, info);
}
public IEqualityComparer<TKey> Comparer => _comparer ?? EqualityComparer<TKey>.Default;
private void CloneFromCollection(IEnumerable<KeyValuePair<TKey, TValue>> collection)
{
// Initialization can be specialized when the source is a SwissTable, so defer it until we know.
// It is likely that the passed-in dictionary is SwissTable<TKey,TValue>. When this is the case,
// avoid the enumerator allocation and overhead by looping through the entries array directly.
// We only do this when dictionary is SwissTable<TKey,TValue> and not a subclass, to maintain
// back-compat with subclasses that may have overridden the enumerator behavior.
if (collection.GetType() == typeof(SwissTable<TKey, TValue>))
{
SwissTable<TKey, TValue> source = (SwissTable<TKey, TValue>)collection;
CloneFromDictionary(source);
return;
}
InitializeInnerTable((collection as ICollection<KeyValuePair<TKey, TValue>>)?.Count ?? 0);
// Fallback path for IEnumerable that isn't a non-subclassed SwissTable<TKey,TValue>.
foreach (KeyValuePair<TKey, TValue> pair in collection)
{
Add(pair.Key, pair.Value);
}
}
private void CloneFromDictionary(SwissTable<TKey, TValue> source)
{
if (source.Count == 0)
{
// TODO: maybe just InitializeInnerTable(0)? Could the CLR optimise this?
rawTable = NewEmptyInnerTable();
return;
}
var oldEntries = source.rawTable._entries;
byte[] oldCtrls = source.rawTable._controls;
Debug.Assert(oldEntries != null);
rawTable = NewInnerTableWithControlUninitialized(source._buckets);
if (_comparer == source._comparer)
{
// This is why we create the non-empty inner table with an uninitialized control array.
Array.Copy(source.rawTable._controls, rawTable._controls, oldCtrls.Length);
Debug.Assert(rawTable._entries != null);
var newEntries = rawTable._entries;
for (int i = 0; i < oldEntries.Length; i++)
{
if (is_full(rawTable._controls[i]))
{
newEntries[i] = oldEntries[i];
}
}
// The control bytes (including tombstones) were copied verbatim, so the remaining growth matches the source.
rawTable._growth_left = source.rawTable._growth_left;
rawTable._count = source.rawTable._count;
return;
}
Array.Fill(rawTable._controls, EMPTY);
// TODO: Maybe we could use IBitMask to accelerate this
for (int i = 0; i < oldEntries.Length; i++)
{
if (is_full(oldCtrls[i]))
{
Add(oldEntries[i].Key, oldEntries[i].Value);
}
}
}
public void Add(TKey key, TValue value)
{
bool modified = TryInsert(key, value, InsertionBehavior.ThrowOnExisting);
Debug.Assert(modified); // If there was an existing key and the Add failed, an exception will already have been thrown.
}
public bool ContainsKey(TKey key) =>
!Unsafe.IsNullRef(ref FindBucket(key));
public bool ContainsValue(TValue value)
{
// TODO: "inline" to get better performance
foreach (var item in new ValueCollection(this))
{
if (EqualityComparer<TValue>.Default.Equals(item, value))
{
return true;
}
}
return false;
}
public bool Remove(TKey key)
{
// TODO: we may need to duplicate most of the code of `Remove(TKey key, out TValue value)` for performance; see the original C# Dictionary implementation.
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
if (rawTable._entries != null)
{
var index = FindBucketIndex(key);
if (index >= 0)
{
erase(index);
_tolerantVersion++;
return true;
}
}
return false;
}
public bool Remove(TKey key, [MaybeNullWhen(false)] out TValue value)
{
// TODO: we may need to duplicate most of the code of `Remove(TKey key)` for performance; see the original C# Dictionary implementation.
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
if (rawTable._entries != null)
{
var index = FindBucketIndex(key);
if (index >= 0)
{
value = erase(index);
_tolerantVersion++;
return true;
}
}
value = default;
return false;
}
public bool TryAdd(TKey key, TValue value) =>
TryInsert(key, value, InsertionBehavior.None);
private KeyCollection? _keys;
public KeyCollection Keys => _keys ??= new KeyCollection(this);
private ValueCollection? _values;
public ValueCollection Values => _values ??= new ValueCollection(this);
public TValue this[TKey key]
{
get
{
ref Entry entry = ref FindBucket(key);
if (!Unsafe.IsNullRef(ref entry))
{
return entry.Value;
}
throw new KeyNotFoundException(key.ToString());
}
set
{
bool modified = TryInsert(key, value, InsertionBehavior.OverwriteExisting);
Debug.Assert(modified);
}
}
public bool TryGetValue(TKey key, [MaybeNullWhen(false)] out TValue value)
{
ref Entry entry = ref FindBucket(key);
if (!Unsafe.IsNullRef(ref entry))
{
value = entry.Value;
return true;
}
value = default;
return false;
}
public void Clear()
{
int count = rawTable._count;
if (count > 0)
{
Debug.Assert(rawTable._entries != null, "_entries should be non-null");
Array.Fill(rawTable._controls, EMPTY);
rawTable._count = 0;
rawTable._growth_left = bucket_mask_to_capacity(rawTable._bucket_mask);
// TODO: maybe we could remove this branch to improve perf. Or maybe CLR has optimised this.
if (RuntimeHelpers.IsReferenceOrContainsReferences<TValue>()
|| RuntimeHelpers.IsReferenceOrContainsReferences<TKey>())
{
Array.Clear(rawTable._entries, 0, rawTable._entries.Length);
}
}
}
private static bool IsCompatibleKey(object key)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
return key is TKey;
}
private void CopyTo(KeyValuePair<TKey, TValue>[] array, int index)
{
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}
if ((uint)index > (uint)array.Length)
{
throw new IndexOutOfRangeException();
}
if (array.Length - index < Count)
{
throw new ArgumentException();
}
CopyToWorker(array, index);
}
// This method does not validate array and index; maybe the name should be CopyToUnsafe?
private void CopyToWorker(KeyValuePair<TKey, TValue>[] array, int index)
{
// TODO: maybe we could pin (`fixed`) the array; then it might be safe to use load_aligned
DispatchCopyToArrayFromDictionaryWorker(this, array, index);
}
#region IDictionary
bool ICollection.IsSynchronized => false;
object ICollection.SyncRoot => this;
bool IDictionary.IsFixedSize => false;
bool IDictionary.IsReadOnly => false;
ICollection IDictionary.Keys => Keys;
ICollection IDictionary.Values => Values;
object? IDictionary.this[object key]
{
get
{
if (IsCompatibleKey(key))
{
ref Entry entry = ref FindBucket((TKey)key);
if (!Unsafe.IsNullRef(ref entry))
{
return entry.Value;
}
}
return null;
}
set
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
// ThrowHelper.IfNullAndNullsAreIllegalThenThrow<TValue>(value, nameof(value));
try
{
TKey tempKey = (TKey)key;
try
{
this[tempKey] = (TValue)value!;
}
catch (InvalidCastException)
{
// ThrowHelper.ThrowWrongValueTypeArgumentException(value, typeof(TValue));
}
}
catch (InvalidCastException)
{
// ThrowHelper.ThrowWrongKeyTypeArgumentException(key, typeof(TKey));
}
}
}
void IDictionary.Add(object key, object? value)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
// ThrowHelper.IfNullAndNullsAreIllegalThenThrow<TValue>(value, nameof(value));
try
{
TKey tempKey = (TKey)key;
try
{
Add(tempKey, (TValue)value!);
}
catch (InvalidCastException)
{
// ThrowHelper.ThrowWrongValueTypeArgumentException(value, typeof(TValue));
}
}
catch (InvalidCastException)
{
// ThrowHelper.ThrowWrongKeyTypeArgumentException(key, typeof(TKey));
}
}
bool IDictionary.Contains(object key)
{
if (IsCompatibleKey(key))
{
return ContainsKey((TKey)key);
}
return false;
}
IDictionaryEnumerator IDictionary.GetEnumerator() => new Enumerator(this, Enumerator.DictEntry);
void IDictionary.Remove(object key)
{
if (IsCompatibleKey(key))
{
Remove((TKey)key);
}
}
void ICollection.CopyTo(Array array, int index)
{
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}
if (array.Rank != 1)
{
// throw new ArgumentException(ExceptionResource.Arg_RankMultiDimNotSupported);
}
if (array.GetLowerBound(0) != 0)
{
// throw new ArgumentException(ExceptionResource.Arg_NonZeroLowerBound);
}
if ((uint)index > (uint)array.Length)
{
throw new IndexOutOfRangeException();
}
if (array.Length - index < Count)
{
// throw new ArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall);
}
if (array is KeyValuePair<TKey, TValue>[] pairs)
{
CopyToWorker(pairs, index);
}
else if (array is DictionaryEntry[] dictEntryArray)
{
foreach (var item in this)
{
dictEntryArray[index++] = new DictionaryEntry(item.Key, item.Value);
}
}
else
{
object[]? objects = array as object[];
if (objects == null)
{
// throw new ArgumentException_Argument_InvalidArrayType();
}
try
{
foreach (var item in this)
{
objects[index++] = new KeyValuePair<TKey, TValue>(item.Key, item.Value);
}
}
catch (ArrayTypeMismatchException)
{
// throw new ArgumentException_Argument_InvalidArrayType();
}
}
}
#endregion
#region IReadOnlyDictionary<TKey, TValue>
IEnumerable<TKey> IReadOnlyDictionary<TKey, TValue>.Keys => Keys;
IEnumerable<TValue> IReadOnlyDictionary<TKey, TValue>.Values => Values;
#endregion
#region IDictionary<TKey, TValue>
ICollection<TKey> IDictionary<TKey, TValue>.Keys => Keys;
ICollection<TValue> IDictionary<TKey, TValue>.Values => Values;
#endregion
#region ICollection<KeyValuePair<TKey, TValue>>
bool ICollection<KeyValuePair<TKey, TValue>>.IsReadOnly => false;
void ICollection<KeyValuePair<TKey, TValue>>.Add(KeyValuePair<TKey, TValue> keyValuePair) =>
Add(keyValuePair.Key, keyValuePair.Value);
bool ICollection<KeyValuePair<TKey, TValue>>.Contains(KeyValuePair<TKey, TValue> keyValuePair)
{
ref Entry bucket = ref FindBucket(keyValuePair.Key);
if (!Unsafe.IsNullRef(ref bucket) && EqualityComparer<TValue>.Default.Equals(bucket.Value, keyValuePair.Value))
{
return true;
}
return false;
}
void ICollection<KeyValuePair<TKey, TValue>>.CopyTo(KeyValuePair<TKey, TValue>[] array, int index) =>
CopyTo(array, index);
bool ICollection<KeyValuePair<TKey, TValue>>.Remove(KeyValuePair<TKey, TValue> keyValuePair)
{
ref Entry bucket = ref FindBucket(keyValuePair.Key);
if (!Unsafe.IsNullRef(ref bucket) && EqualityComparer<TValue>.Default.Equals(bucket.Value, keyValuePair.Value))
{
Remove(keyValuePair.Key);
return true;
}
return false;
}
#endregion
#region IEnumerable<KeyValuePair<TKey, TValue>>
IEnumerator<KeyValuePair<TKey, TValue>> IEnumerable<KeyValuePair<TKey, TValue>>.GetEnumerator() =>
new Enumerator(this, Enumerator.KeyValuePair);
#endregion
#region IEnumerable
IEnumerator IEnumerable.GetEnumerator() => new Enumerator(this, Enumerator.KeyValuePair);
#endregion
#region Serialization/Deserialization
// constants for Serialization/Deserialization
private const string VersionName = "Version"; // Do not rename (binary serialization)
private const string HashSizeName = "HashSize"; // Do not rename (binary serialization)
private const string KeyValuePairsName = "KeyValuePairs"; // Do not rename (binary serialization)
private const string ComparerName = "Comparer"; // Do not rename (binary serialization)
public virtual void GetObjectData(SerializationInfo info, StreamingContext context)
{
if (info == null)
{
throw new ArgumentNullException(nameof(info));
}
info.AddValue(VersionName, _version);
info.AddValue(ComparerName, Comparer, typeof(IEqualityComparer<TKey>));
info.AddValue(HashSizeName, bucket_mask_to_capacity(rawTable._bucket_mask));
if (rawTable._entries != null)
{
var array = new KeyValuePair<TKey, TValue>[Count];
// This is always safe: we allocated the array ourselves, so there is always enough space.
CopyToWorker(array, 0);
info.AddValue(KeyValuePairsName, array, typeof(KeyValuePair<TKey, TValue>[]));
}
}
public virtual void OnDeserialization(object? sender)
{
HashHelpers.SerializationInfoTable.TryGetValue(this, out SerializationInfo? siInfo);
if (siInfo == null)
{
// We can return immediately if this function is called twice.
// Note we remove the serialization info from the table at the end of this method.
return;
}
int realVersion = siInfo.GetInt32(VersionName);
int hashsize = siInfo.GetInt32(HashSizeName);
_comparer = (IEqualityComparer<TKey>)siInfo.GetValue(ComparerName, typeof(IEqualityComparer<TKey>))!; // When serialized if comparer is null, we use the default.
InitializeInnerTable(hashsize);
if (hashsize != 0)
{
KeyValuePair<TKey, TValue>[]? array = (KeyValuePair<TKey, TValue>[]?)
siInfo.GetValue(KeyValuePairsName, typeof(KeyValuePair<TKey, TValue>[]));
if (array == null)
{
// ThrowHelper.ThrowSerializationException(ExceptionResource.Serialization_MissingKeys);
}
for (int i = 0; i < array.Length; i++)
{
if (array[i].Key == null)
{
// ThrowHelper.ThrowSerializationException(ExceptionResource.Serialization_NullKey);
}
Add(array[i].Key, array[i].Value);
}
}
_version = realVersion;
HashHelpers.SerializationInfoTable.Remove(this);
}
#endregion
public Enumerator GetEnumerator() => new Enumerator(this, Enumerator.KeyValuePair);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int GetHashCodeOfKey(TKey key)
{
return (_comparer == null) ? key.GetHashCode() : _comparer.GetHashCode(key);
}
/// <summary>
/// The real implementation behind every public insert method.
/// Does some extra bookkeeping besides the insert itself.
/// </summary>
/// <param name="key"></param>
/// <param name="value"></param>
/// <param name="behavior"></param>
/// <returns></returns>
private bool TryInsert(TKey key, TValue value, InsertionBehavior behavior)
{
// NOTE: this method is mirrored in CollectionsMarshal.GetValueRefOrAddDefault below.
// If you make any changes here, make sure to keep that version in sync as well.
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
var hashOfKey = GetHashCodeOfKey(key);
ref var bucket = ref FindBucket(key, hashOfKey);
// replace
if (!Unsafe.IsNullRef(ref bucket))
{
if (behavior == InsertionBehavior.OverwriteExisting)
{
bucket.Key = key;
bucket.Value = value;
return true;
}
if (behavior == InsertionBehavior.ThrowOnExisting)
{
// ThrowHelper.ThrowAddingDuplicateWithKeyArgumentException(key);
}
// InsertionBehavior.None
return false;
}
// insert new
// We can avoid growing the table once we have reached our load
// factor if we are replacing a tombstone(Delete). This works since the
// number of EMPTY slots does not change in this case.
var index = rawTable.find_insert_slot(hashOfKey);
var old_ctrl = rawTable._controls[index];
if (rawTable._growth_left == 0 && special_is_empty(old_ctrl))
{
EnsureCapacityWorker(rawTable._count + 1);
index = rawTable.find_insert_slot(hashOfKey);
}
Debug.Assert(rawTable._entries != null);
rawTable.record_item_insert_at(index, old_ctrl, hashOfKey);
ref var targetEntry = ref rawTable._entries[index];
targetEntry.Key = key;
targetEntry.Value = value;
_version++;
return true;
}
/// <summary>
/// A helper class containing APIs exposed through <see cref="Runtime.InteropServices.CollectionsMarshal"/>.
/// These methods are relatively niche and only used in specific scenarios, so adding them in a separate type avoids
/// the additional overhead on each <see cref="SwissTable{TKey,TValue}"/> instantiation, especially in AOT scenarios.
/// </summary>
internal static class CollectionsMarshalHelper
{
/// <inheritdoc cref="Runtime.InteropServices.CollectionsMarshal.GetValueRefOrAddDefault{TKey, TValue}(SwissTable{TKey,TValue}, TKey, out bool)"/>
public static ref TValue? GetValueRefOrAddDefault(SwissTable<TKey, TValue> dictionary, TKey key, out bool exists)
{
// NOTE: this method is mirrored by SwissTable<TKey, TValue>.TryInsert above.
// If you make any changes here, make sure to keep that version in sync as well.
ref var bucket = ref dictionary.FindBucket(key);
// replace
if (!Unsafe.IsNullRef(ref bucket))
{
exists = true;
return ref bucket.Value!;
}
// insert new
var hashCode = dictionary.GetHashCodeOfKey(key);
// We can avoid growing the table once we have reached our load
// factor if we are replacing a tombstone(Delete). This works since the
// number of EMPTY slots does not change in this case.
var index = dictionary.rawTable.find_insert_slot(hashCode);
var old_ctrl = dictionary.rawTable._controls[index];
if (dictionary.rawTable._growth_left == 0 && special_is_empty(old_ctrl))
{
dictionary.EnsureCapacityWorker(dictionary.rawTable._count + 1);
index = dictionary.rawTable.find_insert_slot(hashCode);
}
Debug.Assert(dictionary.rawTable._entries != null);
dictionary.rawTable.record_item_insert_at(index, old_ctrl, hashCode);
dictionary.rawTable._entries[index].Key = key;
dictionary.rawTable._entries[index].Value = default!;
dictionary._version++;
exists = false;
return ref dictionary.rawTable._entries[index].Value!;
}
}
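// A hedged usage sketch for GetValueRefOrAddDefault (the helper is internal, so this only works from
// inside the same assembly; `table` and "hits" are assumed names for illustration):
//
//     ref int slot = ref SwissTable<string, int>.CollectionsMarshalHelper
//         .GetValueRefOrAddDefault(table, "hits", out bool exists);
//     slot += 1;   // newly added entries start at default(int) == 0; no second lookup needed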
internal struct Entry
{
internal TKey Key;
internal TValue Value;
}
// Allocates and initializes the inner table for a real capacity.
// Note that _entries might still not be allocated (when the capacity is 0).
// This path must not reuse any existing data, unlike resize or initialization from a dictionary.
// `realisticCapacity` may be any non-negative number.
[SkipLocalsInit]
private void InitializeInnerTable(int realisticCapacity)
{
if (realisticCapacity < 0)
{
// ThrowHelper.ThrowArgumentOutOfRangeException(nameof(capacity));
}
if (realisticCapacity == 0)
{
rawTable = NewEmptyInnerTable();
return;
}
var idealCapacity = capacity_to_buckets(realisticCapacity);
rawTable = NewInnerTableWithControlUninitialized(idealCapacity);
Array.Fill(rawTable._controls, EMPTY);
}
// TODO: check whether we need NonRandomizedStringEqualityComparer
// If we ever need to rehash in place, we should learn from Rust.
// Resizing to capacity 0 is a special, simple case and is not handled here.
private void Grow(int realisticCapacity)
{
Debug.Assert(rawTable._entries != null);
var idealCapacity = capacity_to_buckets(realisticCapacity);
GrowWorker(idealCapacity);
}
[SkipLocalsInit]
private void GrowWorker(int idealEntryLength)
{
Debug.Assert(idealEntryLength >= rawTable._count);
var newTable = NewInnerTableWithControlUninitialized(idealEntryLength);
Array.Fill(newTable._controls, EMPTY);
Debug.Assert(rawTable._entries != null);
Debug.Assert(newTable._entries != null);
Debug.Assert(newTable._count == 0);
Debug.Assert(newTable._growth_left >= rawTable._count);
// We can use a simple version of insert() here since:
// - there are no DELETED entries.
// - we know there is enough space in the table.
byte[] oldCtrls = rawTable._controls;
Entry[] oldEntries = rawTable._entries;
Entry[] newEntries = newTable._entries;
int length = rawTable._entries.Length;
// TODO: Maybe we could use IBitMask to accelerate this
for (int i = 0; i < length; i++)
{
if (is_full(oldCtrls[i]))
{
var key = oldEntries[i].Key;
var hash = GetHashCodeOfKey(key);
var index = newTable.find_insert_slot(hash);
newTable.set_ctrl_h2(index, hash);
newEntries[index] = oldEntries[i];
}
}
newTable._growth_left -= rawTable._count;
newTable._count = rawTable._count;
rawTable = newTable;
}
/// <summary>
/// Ensures that the dictionary can hold up to 'capacity' entries without any further expansion of its backing storage
/// </summary>
public int EnsureCapacity(int capacity)
{
// "capacity" is `this._count + this._growth_left` in the new implementation
if (capacity < 0)
{
// ThrowHelper.ThrowArgumentOutOfRangeException(nameof(capacity));
}
int currentCapacity = rawTable._count + rawTable._growth_left;
if (currentCapacity >= capacity)
{
return currentCapacity;
}
EnsureCapacityWorker(capacity);
return rawTable._count + rawTable._growth_left;
}
private void EnsureCapacityWorker(int capacity)
{
_version++;
if (rawTable._entries == null)
{
InitializeInnerTable(capacity);
}
else
{
Grow(capacity);
}
}
/// <summary>
/// Sets the capacity of this dictionary to what it would be if it had been originally initialized with all its entries
/// </summary>
/// <remarks>
/// This method can be used to minimize the memory overhead
/// once it is known that no new elements will be added.
///
/// To allocate minimum size storage array, execute the following statements:
///
/// dictionary.Clear();
/// dictionary.TrimExcess();
/// </remarks>
public void TrimExcess() => TrimExcess(Count);
/// <summary>
/// Sets the capacity of this dictionary to hold up to 'capacity' entries without any further expansion of its backing storage
/// </summary>
/// <remarks>
/// This method can be used to minimize the memory overhead
/// once it is known that no new elements will be added.
/// </remarks>
public void TrimExcess(int capacity)
{
if (capacity < Count)
{
// ThrowHelper.ThrowArgumentOutOfRangeException(nameof(capacity));
}
if (capacity == 0)
{
_version++;
// TODO: No need to initialize if _entries is null.
InitializeInnerTable(capacity);
return;
}
var idealBuckets = capacity_to_buckets(capacity);
// TODO: if the length is the same, we might not need to resize; see `rehash_in_place` in the Rust implementation.
if (idealBuckets <= _buckets)
{
_version++;
GrowWorker(idealBuckets);
}
}
/// <summary>
/// Creates a new empty hash table without allocating any memory, using the
/// given allocator.
///
/// In effect this returns a table with exactly 1 bucket. However we can
/// leave the data pointer dangling since that bucket is never written to
/// due to our load factor forcing us to always have at least 1 free bucket.
/// </summary>
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static RawTableInner NewEmptyInnerTable()
{
return new RawTableInner
{
_bucket_mask = 0,
_controls = DispatchGetEmptyControls(),
_entries = null,
_growth_left = 0,
_count = 0
};
}
/// <summary>
/// Allocates a new hash table with the given number of buckets.
///
/// The control bytes are left uninitialized; callers must fill them with EMPTY (or copy existing controls).
/// </summary>
// Unlike Rust, we never care about out-of-memory here.
// TODO: Maybe use ref to avoid copying the struct and improve performance?
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static RawTableInner NewInnerTableWithControlUninitialized(int buckets)
{
// Debug.Assert(BitOperations.IsPow2(buckets));
return new RawTableInner
{
_bucket_mask = buckets - 1,
_controls = new byte[buckets + GROUP_WIDTH],
_entries = new Entry[buckets],
_growth_left = bucket_mask_to_capacity(buckets - 1),
_count = 0
};
}
internal ref TValue FindValue(TKey key)
{
// TODO: We might choose to duplicate here too, just like FindBucketIndex, but not for now.
ref Entry bucket = ref FindBucket(key);
if (Unsafe.IsNullRef(ref bucket))
{
return ref Unsafe.NullRef<TValue>();
}
return ref bucket.Value;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ref Entry FindBucket(TKey key)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
var hash = GetHashCodeOfKey(key);
return ref DispatchFindBucketOfDictionary(this, key, hash);
}
// Sometimes we need to reuse the hash rather than calculate it twice.
// The caller should check that key is not null, since it has to compute the hash first.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private ref Entry FindBucket(TKey key, int hashOfKey)
{
return ref DispatchFindBucketOfDictionary(this, key, hashOfKey);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private int FindBucketIndex(TKey key)
{
if (key == null)
{
throw new ArgumentNullException(nameof(key));
}
return DispatchFindBucketIndexOfDictionary(this, key);
}
private TValue erase(int index)
{
// Note: we cannot just set the control byte to DELETED and consider the entry gone; the reference is still stored, so the GC would not collect it.
Debug.Assert(is_full(rawTable._controls[index]));
Debug.Assert(rawTable._entries != null, "entries should be non-null");
var isEraseSafeToSetEmptyControlFlag = DispatchIsEraseSafeToSetEmptyControlFlag(rawTable._bucket_mask, rawTable._controls, index);
TValue res;
byte ctrl;
if (isEraseSafeToSetEmptyControlFlag)
{
ctrl = EMPTY;
rawTable._growth_left += 1;
}
else
{
ctrl = DELETED;
}
rawTable.set_ctrl(index, ctrl);
rawTable._count -= 1;
res = rawTable._entries[index].Value;
// TODO: maybe we could remove this branch to improve perf. Or maybe CLR has optimised this.
if (RuntimeHelpers.IsReferenceOrContainsReferences<TKey>())
{
rawTable._entries[index].Key = default!;
}
if (RuntimeHelpers.IsReferenceOrContainsReferences<TValue>())
{
rawTable._entries[index].Value = default!;
}
return res;
}
/// Returns the number of buckets needed to hold the given number of items,
/// taking the maximum load factor into account.
private static int capacity_to_buckets(int cap)
{
Debug.Assert(cap > 0);
var capacity = (uint)cap;
// For small tables we require at least 1 empty bucket so that lookups are
// guaranteed to terminate if an element doesn't exist in the table.
if (capacity < 8)
{
// We don't bother with a table size of 2 buckets since that can only
// hold a single element. Instead we skip directly to a 4 bucket table
// which can hold 3 elements.
return (capacity < 4 ? 4 : 8);
}
// Otherwise require 1/8 buckets to be empty (87.5% load)
//
// Be careful when modifying this, calculate_layout relies on the
// overflow check here.
uint adjusted_capacity;
switch (capacity)
{
// 0x01FFFFFF is the max value that would not overflow when *8
case <= 0x01FFFFFF:
adjusted_capacity = unchecked(capacity * 8 / 7);
break;
// 0x37FFFFFF is the largest value whose adjusted capacity (* 8 / 7) stays below 0x4000_0000
case <= 0x37FFFFFF:
return 0x4000_0000;
default:
throw new Exception("capacity overflow");
}
// Any overflow has already been rejected by the range check above. Also, any
// rounding errors from the division above will be cleaned up by
// nextPowerOfTwo (which can't overflow because of the previous division).
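// Worked example: cap = 12
//   adjusted_capacity = 12 * 8 / 7 = 13
//   nextPowerOfTwo(13) = 1 << (32 - lzcnt(13)) = 1 << 4 = 16 buckets
// and 16 buckets give bucket_mask_to_capacity(15) = 14 >= 12, as required.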
return nextPowerOfTwo(adjusted_capacity);
static int nextPowerOfTwo(uint num)
{
return (int)(0x01u << (32 - math.lzcnt(num)));
}
}
/// Returns the maximum effective capacity for the given bucket mask, taking
/// the maximum load factor into account.
private static int bucket_mask_to_capacity(int bucket_mask)
{
if (bucket_mask < 8)
{
// For tables with 1/2/4/8 buckets, we always reserve one empty slot.
// Keep in mind that the bucket mask is one less than the bucket count.
return bucket_mask;
}
// For larger tables we reserve 12.5% of the slots as empty.
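// e.g. bucket_mask = 15 (16 buckets): ((15 + 1) >> 3) * 7 = 2 * 7 = 14 usable entries.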
return ((bucket_mask + 1) >> 3) * 7; // bucket_mask / 8 * 7, but written with a shift because signed division generates slightly more complex code; maybe we should use uint?
}
public struct Enumerator : IEnumerator<KeyValuePair<TKey, TValue>>, IDictionaryEnumerator
{
private readonly SwissTable<TKey, TValue> _dictionary;
private readonly int _version;
private readonly int _tolerantVersion;
private KeyValuePair<TKey, TValue> _current;
private BitMaskUnion _currentBitMask;
internal int _currentCtrlOffset;
private readonly int _getEnumeratorRetType;
private bool _isValid; // valid when _current has correct value.
internal const int DictEntry = 1;
internal const int KeyValuePair = 2;
internal Enumerator(SwissTable<TKey, TValue> dictionary, int getEnumeratorRetType)
{
_dictionary = dictionary;
_version = dictionary._version;
_tolerantVersion = dictionary._tolerantVersion;
_getEnumeratorRetType = getEnumeratorRetType;
_current = default;
_currentCtrlOffset = 0;
_isValid = false;
_currentBitMask = DispatchGetMatchFullBitMask(_dictionary.rawTable._controls, 0);
}
#region IDictionaryEnumerator
DictionaryEntry IDictionaryEnumerator.Entry
{
get
{
if (!_isValid)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen();
}
Debug.Assert(_current.Key != null);
return new DictionaryEntry(_current.Key, _current.Value);
}
}
object IDictionaryEnumerator.Key
{
get
{
if (!_isValid)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen();
}
Debug.Assert(_current.Key != null);
return _current.Key;
}
}
object? IDictionaryEnumerator.Value
{
get
{
if (!_isValid)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen();
}
return _current.Value;
}
}
#endregion
#region IEnumerator<KeyValuePair<TKey, TValue>>
public KeyValuePair<TKey, TValue> Current => _current;
object? IEnumerator.Current
{
get
{
if (!_isValid)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen();
}
Debug.Assert(_current.Key != null);
if (_getEnumeratorRetType == DictEntry)
{
return new DictionaryEntry(_current.Key, _current.Value);
}
return new KeyValuePair<TKey, TValue>(_current.Key, _current.Value);
}
}
public void Dispose() { }
public bool MoveNext()
{
ref var entry = ref DispatchMoveNextDictionary(_version, _tolerantVersion, _dictionary, ref _currentCtrlOffset, ref _currentBitMask);
if (!Unsafe.IsNullRef(ref entry))
{
_isValid = true;
_current = new KeyValuePair<TKey, TValue>(entry.Key, entry.Value);
return true;
}
_current = default;
_isValid = false;
return false;
}
void IEnumerator.Reset()
{
if (_version != _dictionary._version)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion();
}
_current = default;
_currentCtrlOffset = 0;
_isValid = false;
_currentBitMask = DispatchGetMatchFullBitMask(_dictionary.rawTable._controls, 0);
}
#endregion
}
// [DebuggerTypeProxy(typeof(DictionaryKeyCollectionDebugView<,>))]
[DebuggerDisplay("Count = {Count}")]
public sealed class KeyCollection : ICollection<TKey>, ICollection, IReadOnlyCollection<TKey>
{
private readonly SwissTable<TKey, TValue> _dictionary;
public KeyCollection(SwissTable<TKey, TValue> dictionary)
{
if (dictionary == null)
{
throw new ArgumentNullException(nameof(dictionary));
}
_dictionary = dictionary;
}
public Enumerator GetEnumerator() => new Enumerator(_dictionary);
public void CopyTo(TKey[] array, int index)
{
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}
Debug.Assert(array != null);
if (index < 0 || index > array.Length)
{
throw new IndexOutOfRangeException();
}
if (array.Length - index < _dictionary.Count)
{
// throw new ArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall);
}
// TODO: we might also use SIMD to scan the control bytes, which would perform better for sparse tables.
foreach (var item in this)
{
array[index++] = item;
}
}
public int Count => _dictionary.Count;
bool ICollection<TKey>.IsReadOnly => true;
void ICollection<TKey>.Add(TKey item) =>
throw new NotSupportedException();
void ICollection<TKey>.Clear() =>
throw new NotSupportedException();
bool ICollection<TKey>.Contains(TKey item) =>
_dictionary.ContainsKey(item);
bool ICollection<TKey>.Remove(TKey item)
{
throw new NotSupportedException();
return false;
}
IEnumerator<TKey> IEnumerable<TKey>.GetEnumerator() => new Enumerator(_dictionary);
IEnumerator IEnumerable.GetEnumerator() => new Enumerator(_dictionary);
void ICollection.CopyTo(Array array, int index)
{
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}
if (array.Rank != 1)
{
// throw new ArgumentException(ExceptionResource.Arg_RankMultiDimNotSupported);
}
if (array.GetLowerBound(0) != 0)
{
// throw new ArgumentException(ExceptionResource.Arg_NonZeroLowerBound);
}
if ((uint)index > (uint)array.Length)
{
throw new IndexOutOfRangeException();
}
if (array.Length - index < _dictionary.Count)
{
// throw new ArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall);
}
if (array is TKey[] keys)
{
CopyTo(keys, index);
}
else
{
object[]? objects = array as object[];
if (objects == null)
{
// throw new ArgumentException_Argument_InvalidArrayType();
}
try
{
foreach (var item in this)
{
objects[index++] = item;
}
}
catch (ArrayTypeMismatchException)
{
// throw new ArgumentException_Argument_InvalidArrayType();
}
}
}
bool ICollection.IsSynchronized => false;
object ICollection.SyncRoot => ((ICollection)_dictionary).SyncRoot;
public struct Enumerator : IEnumerator<TKey>, IEnumerator
{
private readonly SwissTable<TKey, TValue> _dictionary;
private readonly int _version;
private readonly int _tolerantVersion;
private BitMaskUnion _currentBitMask;
internal int _currentCtrlOffset;
private bool _isValid; // valid when _current has correct value.
private TKey? _current;
internal Enumerator(SwissTable<TKey, TValue> dictionary)
{
_dictionary = dictionary;
_version = dictionary._version;
_tolerantVersion = dictionary._tolerantVersion;
_current = default;
_currentCtrlOffset = 0;
_isValid = false;
_currentBitMask = DispatchGetMatchFullBitMask(_dictionary.rawTable._controls, 0);
}
public void Dispose() { }
public bool MoveNext()
{
ref var entry = ref DispatchMoveNextDictionary(_version, _tolerantVersion, _dictionary, ref _currentCtrlOffset, ref _currentBitMask);
if (!Unsafe.IsNullRef(ref entry))
{
_isValid = true;
_current = entry.Key;
return true;
}
_current = default;
_isValid = false;
return false;
}
public TKey Current => _current!;
object? IEnumerator.Current
{
get
{
if (!_isValid)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen();
}
return _current;
}
}
void IEnumerator.Reset()
{
if (_version != _dictionary._version)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion();
}
_current = default;
_currentCtrlOffset = 0;
_isValid = false;
_currentBitMask = DispatchGetMatchFullBitMask(_dictionary.rawTable._controls, 0);
}
}
}
// [DebuggerTypeProxy(typeof(DictionaryValueCollectionDebugView<,>))]
[DebuggerDisplay("Count = {Count}")]
public sealed class ValueCollection : ICollection<TValue>, ICollection, IReadOnlyCollection<TValue>
{
private readonly SwissTable<TKey, TValue> _dictionary;
public ValueCollection(SwissTable<TKey, TValue> dictionary)
{
if (dictionary == null)
{
throw new ArgumentNullException(nameof(dictionary));
}
_dictionary = dictionary;
}
public Enumerator GetEnumerator() => new Enumerator(_dictionary);
public void CopyTo(TValue[] array, int index)
{
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}
if ((uint)index > array.Length)
{
throw new IndexOutOfRangeException();
}
if (array.Length - index < _dictionary.Count)
{
// throw new ArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall);
}
foreach (var item in this)
{
array[index++] = item;
}
}
public int Count => _dictionary.Count;
bool ICollection<TValue>.IsReadOnly => true;
void ICollection<TValue>.Add(TValue item) =>
throw new NotSupportedException();
bool ICollection<TValue>.Remove(TValue item)
{
throw new NotSupportedException();
return false;
}
void ICollection<TValue>.Clear() =>
throw new NotSupportedException();
bool ICollection<TValue>.Contains(TValue item) => _dictionary.ContainsValue(item);
IEnumerator<TValue> IEnumerable<TValue>.GetEnumerator() => new Enumerator(_dictionary);
IEnumerator IEnumerable.GetEnumerator() => new Enumerator(_dictionary);
void ICollection.CopyTo(Array array, int index)
{
if (array == null)
{
throw new ArgumentNullException(nameof(array));
}
if (array.Rank != 1)
{
// throw new ArgumentException(ExceptionResource.Arg_RankMultiDimNotSupported);
}
if (array.GetLowerBound(0) != 0)
{
// throw new ArgumentException(ExceptionResource.Arg_NonZeroLowerBound);
}
if ((uint)index > (uint)array.Length)
{
throw new IndexOutOfRangeException();
}
if (array.Length - index < _dictionary.Count)
{
// throw new ArgumentException(ExceptionResource.Arg_ArrayPlusOffTooSmall);
}
if (array is TValue[] values)
{
CopyTo(values, index);
}
else
{
object[]? objects = array as object[];
if (objects == null)
{
// throw new ArgumentException_Argument_InvalidArrayType();
}
try
{
foreach (var item in this)
{
objects[index++] = item!;
}
}
catch (ArrayTypeMismatchException)
{
// throw new ArgumentException_Argument_InvalidArrayType();
}
}
}
bool ICollection.IsSynchronized => false;
object ICollection.SyncRoot => ((ICollection)_dictionary).SyncRoot;
public struct Enumerator : IEnumerator<TValue>, IEnumerator
{
private readonly SwissTable<TKey, TValue> _dictionary;
private readonly int _version;
private readonly int _tolerantVersion;
private BitMaskUnion _currentBitMask;
internal int _currentCtrlOffset;
private bool _isValid; // valid when _current has correct value.
private TValue? _current;
internal Enumerator(SwissTable<TKey, TValue> dictionary)
{
_dictionary = dictionary;
_version = dictionary._version;
_tolerantVersion = dictionary._tolerantVersion;
_current = default;
_currentCtrlOffset = 0;
_isValid = false;
_currentBitMask = DispatchGetMatchFullBitMask(_dictionary.rawTable._controls, 0);
}
public void Dispose() { }
public bool MoveNext()
{
ref var entry = ref DispatchMoveNextDictionary(_version, _tolerantVersion, _dictionary, ref _currentCtrlOffset, ref _currentBitMask);
if (!Unsafe.IsNullRef(ref entry))
{
_isValid = true;
_current = entry.Value;
return true;
}
_current = default;
_isValid = false;
return false;
}
public TValue Current => _current!;
object? IEnumerator.Current
{
get
{
if (!_isValid)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumOpCantHappen();
}
return _current;
}
}
void IEnumerator.Reset()
{
if (_version != _dictionary._version)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion();
}
_current = default;
_currentCtrlOffset = 0;
_isValid = false;
_currentBitMask = DispatchGetMatchFullBitMask(_dictionary.rawTable._controls, 0);
}
}
}
}
/// <summary>
/// A bit mask which contains the result of a `Match` operation on a `Group` and
/// allows iterating through them.
///
/// The bit mask is arranged so that low-order bits represent lower memory
/// addresses for group match results.
///
/// For implementation reasons, the bits in the set may be sparsely packed, so
/// that there is only one bit-per-byte used (the high bit, 7). If this is the
/// case, `BITMASK_STRIDE` will be 8 to indicate a divide-by-8 should be
/// performed on counts/indices to normalize this difference. `BITMASK_MASK` is
/// similarly a mask of all the actually-used bits.
/// </summary>
// The generic parameter is for performance.
// All implementations should be structs for performance; however, if the return type were the interface, there would be boxing and unboxing,
// and inlining would not work either.
// So we pass the concrete implementation as a type parameter instead.
internal interface IBitMask<BitMaskImpl> where BitMaskImpl : struct, IBitMask<BitMaskImpl>
{
/// <summary>
/// Returns a new `BitMask` with all bits inverted.
/// </summary>
/// <returns></returns>
BitMaskImpl Invert();
/// <summary>
/// Returns a new `BitMask` with the lowest bit removed.
/// </summary>
/// <returns></returns>
BitMaskImpl RemoveLowestBit();
/// <summary>
/// Returns a new `BitMask` with the internal data logic and.
/// </summary>
/// <param name="bitMask"> must be the same type with caller</param>
/// <returns></returns>
BitMaskImpl And(BitMaskImpl bitMask);
/// <summary>
/// Returns whether the `BitMask` has at least one set bit.
/// </summary>
/// <returns></returns>
bool AnyBitSet();
/// <summary>
/// Returns the first set bit in the `BitMask`, if there is one.
/// TODO: use negative rather than nullable to represent no bit set.
/// </summary>
/// <returns>
/// negative means not any bit is set.
/// </returns>
int LowestSetBit();
/// <summary>
/// Returns the first set bit in the `BitMask`, if there is one. The
/// bitmask must not be empty.
/// </summary>
/// <returns></returns>
// #[cfg(feature = "nightly")]
int LowestSetBitNonzero();
/// <summary>
/// Returns the number of trailing zeroes in the `BitMask`.
/// </summary>
/// <returns></returns>
int TrailingZeros();
/// <summary>
/// Returns the number of leading zeroes in the `BitMask`.
/// </summary>
/// <returns></returns>
int LeadingZeros();
}
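// A typical consumption pattern for these bitmasks. This is a sketch of how a dispatch helper
// might iterate matches; `group`, `h2`, and `groupStart` are assumed names, not defined in this file:
//
//     var mask = group.MatchByte(h2);              // one bit (or byte lane) per matching control byte
//     while (mask.AnyBitSet())
//     {
//         int offset = mask.LowestSetBitNonzero(); // index of the matching byte within the group
//         /* probe the bucket at groupStart + offset, re-checking the full key */
//         mask = mask.RemoveLowestBit();
//     }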
/// After C# 11, `static_empty`, `create`, `load` and `load_aligned` should become static abstract methods
internal interface IGroup<BitMaskImpl, GroupImpl>
where BitMaskImpl : struct, IBitMask<BitMaskImpl>
where GroupImpl : struct, IGroup<BitMaskImpl, GroupImpl>
{
///// <summary>
///// Returns a full group of empty bytes, suitable for use as the initial
///// value for an empty hash table.
///// </summary>
///// <returns></returns>
////byte[] static_empty { get; }
///// <summary>
///// The number of bytes that the group data occupies
///// </summary>
///// <remarks>
///// The implementation should have `readonly` modifier
///// </remarks>
////int WIDTH { get; }
////unsafe GroupImpl load(byte* ptr);
///// <summary>
///// Loads a group of bytes starting at the given address, which must be
///// aligned to the WIDTH
///// </summary>
///// <param name="ptr"></param>
///// <returns></returns>
////unsafe GroupImpl load_aligned(byte* ptr);
/// <summary>
/// Performs the following transformation on all bytes in the group:
/// - `EMPTY => EMPTY`
/// - `DELETED => EMPTY`
/// - `FULL => DELETED`
/// </summary>
/// <returns></returns>
GroupImpl convert_special_to_empty_and_full_to_deleted();
/// <summary>
/// Stores the group of bytes to the given address, which must be
/// aligned to WIDTH
/// </summary>
/// <param name="ptr"></param>
unsafe void StoreAligned(byte* ptr);
/// <summary>
/// Returns a `BitMask` indicating all bytes in the group which have
/// the given value.
/// </summary>
/// <param name="b"></param>
/// <returns></returns>
BitMaskImpl MatchByte(byte b);
// <summary>
// Returns a `GroupImpl` with the given byte broadcast to every lane.
// </summary>
// <param name="b"></param>
// <returns></returns>
// match_byte would be good enough; however, we do not have readonly parameters yet,
// so we would need to add this as an optimisation.
// GroupImpl create(byte b);
/// <summary>
/// Returns a `BitMask` indicating all bytes in the group is matched with another group
/// </summary>
/// <param name="group"></param>
/// <returns></returns>
BitMaskImpl MatchGroup(GroupImpl group);
/// <summary>
/// Returns a `BitMask` indicating all bytes in the group which are
/// `EMPTY`.
/// </summary>
/// <returns></returns>
BitMaskImpl MatchEmpty();
/// <summary>
/// Returns a `BitMask` indicating all bytes in the group which are
/// `EMPTY` or `DELETED`.
/// </summary>
/// <returns></returns>
BitMaskImpl MatchEmptyOrDeleted();
/// <summary>
/// Returns a `BitMask` indicating all bytes in the group which are full.
/// </summary>
/// <returns></returns>
BitMaskImpl MatchFull();
}
[BurstCompile]
internal struct Avx2BitMask : IBitMask<Avx2BitMask>
{
private const uint BITMASK_MASK = 0xffff_ffff;
// 256 / 8 = 32, so choose uint
internal readonly uint _data;
internal Avx2BitMask(uint data)
{
_data = data;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask Invert()
{
return new Avx2BitMask((_data ^ BITMASK_MASK));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool AnyBitSet()
{
return _data != 0;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LeadingZeros()
{
return math.lzcnt(_data);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LowestSetBit()
{
if (_data == 0)
{
return -1;
}
return LowestSetBitNonzero();
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LowestSetBitNonzero()
{
return TrailingZeros();
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask RemoveLowestBit()
{
return new Avx2BitMask(_data & (_data - 1));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int TrailingZeros()
{
return math.tzcnt(_data);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask And(Avx2BitMask bitMask)
{
return new Avx2BitMask((_data & bitMask._data));
}
}
// TODO: suppress default initialization.
[BurstCompile]
internal struct Avx2Group : IGroup<Avx2BitMask, Avx2Group>
{
public static int WIDTH => 256 / 8;
private readonly v256 _data;
internal Avx2Group(v256 data) { _data = data; }
public static readonly byte[] StaticEmpty = InitialStaticEmpty();
private static byte[] InitialStaticEmpty()
{
var res = new byte[WIDTH];
Array.Fill(res, EMPTY);
return res;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe Avx2Group Load(byte* ptr)
{
// unaligned
return new Avx2Group(Avx.mm256_loadu_si256((v256*)ptr));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe Avx2Group LoadAligned(byte* ptr)
{
Debug.Assert(((uint)ptr & (WIDTH - 1)) == 0);
return new Avx2Group(Avx.mm256_load_si256((v256*)ptr));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public unsafe void StoreAligned(byte* ptr)
{
Debug.Assert(((uint)ptr & (WIDTH - 1)) == 0);
Avx.mm256_store_si256((v256*)ptr, _data);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask MatchByte(byte b)
{
var compareValue = Avx.mm256_set1_epi8(b);
var cmp = Avx2.mm256_cmpeq_epi8(_data, compareValue);
return new Avx2BitMask((uint)Avx2.mm256_movemask_epi8(cmp));
}
private static readonly Avx2Group EmptyGroup = Create(EMPTY);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask MatchEmpty() => MatchGroup(EmptyGroup);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask MatchEmptyOrDeleted()
{
// high bit of each byte -> bitmask
return new Avx2BitMask((uint)Avx2.mm256_movemask_epi8(_data));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask MatchFull() => MatchEmptyOrDeleted().Invert();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2Group convert_special_to_empty_and_full_to_deleted()
{
// special = (0 > data) (treat bytes as signed)
var zero = Avx.mm256_setzero_si256();
var special = Avx2.mm256_cmpgt_epi8(zero, _data);
var hiBit = Avx.mm256_set1_epi8(0x80);
return new Avx2Group(Avx2.mm256_or_si256(special, hiBit));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Avx2Group Create(byte b) => new Avx2Group(Avx.mm256_set1_epi8(b));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Avx2BitMask MatchGroup(Avx2Group group)
{
var cmp = Avx2.mm256_cmpeq_epi8(_data, group._data);
return new Avx2BitMask((uint)Avx2.mm256_movemask_epi8(cmp));
}
}
[BurstCompile]
internal struct FallbackBitMask : IBitMask<FallbackBitMask>
{
// Why nuint/nint?
// On a 64-bit platform we can compare 8 buckets at a time;
// on a 32-bit platform we can compare 4 buckets at a time.
// Accessing the data might also be faster because it is aligned, but we are not sure.
private readonly nuint _data;
private static nuint BITMASK_MASK => (nuint)0x8080_8080_8080_8080;
private const int BITMASK_SHIFT = 3;
internal FallbackBitMask(nuint data)
{
_data = data;
}
/// Returns a new `BitMask` with all bits inverted.
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackBitMask Invert()
{
return new FallbackBitMask(_data ^ BITMASK_MASK);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool AnyBitSet()
{
return _data != 0;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LeadingZeros()
{
#if TARGET_64BIT
// _data is 64 bits wide here, so count leading zeros over the full 64-bit value.
return math.lzcnt((ulong)_data) >> BITMASK_SHIFT;
#else
// _data is 32 bits wide here; lzcnt over uint already counts within the right width.
return math.lzcnt((uint)_data) >> BITMASK_SHIFT;
#endif
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LowestSetBit()
{
if (_data == 0)
{
return -1;
}
return LowestSetBitNonzero();
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LowestSetBitNonzero()
{
return TrailingZeros();
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackBitMask RemoveLowestBit()
{
return new FallbackBitMask(_data & (_data - 1));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int TrailingZeros()
{
return math.tzcnt((uint)_data) >> BITMASK_SHIFT;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackBitMask And(FallbackBitMask bitMask)
{
return new FallbackBitMask(_data & bitMask._data);
}
}
[BurstCompile]
internal struct FallbackGroup : IGroup<FallbackBitMask, FallbackGroup>
{
public static unsafe int WIDTH => sizeof(nuint);
public static readonly byte[] static_empty = InitialStaticEmpty();
private static byte[] InitialStaticEmpty()
{
var res = new byte[WIDTH];
Array.Fill(res, EMPTY);
return res;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe FallbackGroup load(byte* ptr)
{
return new FallbackGroup(Unsafe.ReadUnaligned<nuint>(ptr));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe FallbackGroup load_aligned(byte* ptr)
{
// Casting the pointer to uint is OK here: we only check the low alignment bits (WIDTH is small).
Debug.Assert(((uint)ptr & (WIDTH - 1)) == 0);
return new FallbackGroup(Unsafe.Read<nuint>(ptr));
}
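// Broadcasts a byte into every byte lane of a native word, e.g. on a 64-bit platform
// repeat(0x80) == 0x8080_8080_8080_8080. Used below to compare all lanes at once.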
private static nuint repeat(byte b)
{
nuint res = 0;
for (int i = 0; i < WIDTH; i++)
{
// Shift the accumulator left by one lane and OR the byte into the lowest lane.
res <<= 8;
res |= b;
}
return res;
}
private readonly nuint _data;
internal FallbackGroup(nuint data)
{
_data = data;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackGroup convert_special_to_empty_and_full_to_deleted()
{
// Map high_bit = 1 (EMPTY or DELETED) to 1111_1111
// and high_bit = 0 (FULL) to 1000_0000
//
// Here's this logic expanded to concrete values:
// let full = 1000_0000 (true) or 0000_0000 (false)
// !1000_0000 + 1 = 0111_1111 + 1 = 1000_0000 (no carry)
// !0000_0000 + 0 = 1111_1111 + 0 = 1111_1111 (no carry)
nuint full = ~_data & (nuint)0x8080_8080_8080_8080;
var q = (full >> 7);
var w = ~full + q;
return new FallbackGroup(w);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public unsafe void StoreAligned(byte* ptr)
{
// Casting the pointer to uint is OK here: we only check the low alignment bits (WIDTH is small).
Debug.Assert(((uint)ptr & (WIDTH - 1)) == 0);
Unsafe.Write(ptr, _data);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackBitMask MatchByte(byte b)
{
// This algorithm is derived from
// https://graphics.stanford.edu/~seander/bithacks.html##ValueInWord
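// Per-lane sketch of the trick (illustrative, per byte): let x = lane ^ b, so x == 0 exactly when the lane equals b.
//   (x - 0x01) & ~x & 0x80  yields 0x80 for x == 0 and 0 otherwise, ignoring cross-lane borrows,
//   which can cause occasional false positives; the caller tolerates these by re-checking the key.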
var cmp = _data ^ repeat(b);
var res = (cmp - (nuint)0x0101_0101_0101_0101) & ~cmp & (nuint)0x8080_8080_8080_8080;
return new FallbackBitMask(res);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackBitMask MatchEmpty()
{
// If the high bit is set, then the byte must be either:
// 1111_1111 (EMPTY) or 1000_0000 (DELETED).
// So we can just check if the top two bits are 1 by ANDing them.
return new FallbackBitMask(_data & _data << 1 & (nuint)0x8080_8080_8080_8080);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackBitMask MatchEmptyOrDeleted()
{
// A byte is EMPTY or DELETED iff the high bit is set
return new FallbackBitMask(_data & (nuint)0x8080_8080_8080_8080);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public FallbackBitMask MatchFull()
{
return MatchEmptyOrDeleted().Invert();
}
public static FallbackGroup create(byte b)
{
throw new NotImplementedException();
}
public FallbackBitMask MatchGroup(FallbackGroup group)
{
throw new NotImplementedException();
}
}
[BurstCompile]
internal struct Sse2BitMask : IBitMask<Sse2BitMask>
{
private const ushort BITMASK_MASK = 0xffff;
// 128 / 8 = 16, so choose ushort
internal readonly ushort _data;
internal Sse2BitMask(ushort data)
{
_data = data;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask Invert()
{
return new Sse2BitMask((ushort)(_data ^ BITMASK_MASK));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public bool AnyBitSet()
{
return _data != 0;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LeadingZeros()
{
// magic number `16`:
// `_data` is a `ushort`, but `math.lzcnt` operates on a `uint`,
// so subtract the 16 extra leading zero bits added by the widening cast.
return math.lzcnt((uint)_data) - 16;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LowestSetBit()
{
if (_data == 0)
{
return -1;
}
return LowestSetBitNonzero();
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int LowestSetBitNonzero()
{
return TrailingZeros();
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask RemoveLowestBit()
{
return new Sse2BitMask((ushort)(_data & (_data - 1)));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public int TrailingZeros()
{
return math.tzcnt((uint)_data);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask And(Sse2BitMask bitMask)
{
return new Sse2BitMask((ushort)(_data & bitMask._data));
}
}
// TODO: suppress default initialization.
[BurstCompile]
internal struct Sse2Group : IGroup<Sse2BitMask, Sse2Group>
{
public static int WIDTH => 128 / 8;
private readonly v128 _data;
internal Sse2Group(v128 data) { _data = data; }
public static readonly byte[] static_empty = InitialStaticEmpty();
private static byte[] InitialStaticEmpty()
{
var res = new byte[WIDTH];
Array.Fill(res, EMPTY);
return res;
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe Sse2Group load(byte* ptr)
{
// unaligned
return new Sse2Group(Sse2.loadu_si128((v128*)ptr));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static unsafe Sse2Group load_aligned(byte* ptr)
{
Debug.Assert(((uint)ptr & (WIDTH - 1)) == 0);
return new Sse2Group(Sse2.load_si128((v128*)ptr));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public unsafe void StoreAligned(byte* ptr)
{
Debug.Assert(((uint)ptr & (WIDTH - 1)) == 0);
Sse2.store_si128((v128*)ptr, _data);
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask MatchByte(byte b)
{
var compareValue = Sse2.set1_epi8((sbyte)b);
var cmp = Sse2.cmpeq_epi8(_data, compareValue);
return new Sse2BitMask((ushort)Sse2.movemask_epi8(cmp));
}
private static readonly Sse2Group EmptyGroup = Create(EMPTY);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask MatchEmpty() => MatchGroup(EmptyGroup);
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask MatchEmptyOrDeleted()
{
return new Sse2BitMask((ushort)Sse2.movemask_epi8(_data));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask MatchFull() => MatchEmptyOrDeleted().Invert();
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2Group convert_special_to_empty_and_full_to_deleted()
{
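// `cmpgt(zero, data)` yields 0xFF for lanes whose high bit is set (EMPTY or DELETED,
// i.e. negative as signed bytes) and 0x00 for FULL lanes; OR-ing in 0x80 then maps
// special lanes to 0xFF (EMPTY) and FULL lanes to 0x80 (DELETED).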
var zero = Sse2.setzero_si128();
var special = Sse2.cmpgt_epi8(zero, _data);
var hiBit = Sse2.set1_epi8(unchecked((sbyte)0x80));
return new Sse2Group(Sse2.or_si128(special, hiBit));
}
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static Sse2Group Create(byte b) => new Sse2Group(Sse2.set1_epi8((sbyte)b));
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public Sse2BitMask MatchGroup(Sse2Group group)
{
var cmp = Sse2.cmpeq_epi8(_data, group._data);
return new Sse2BitMask((ushort)Sse2.movemask_epi8(cmp));
}
}
#pragma warning disable CA1810 // Initialize reference type static fields inline
// Probe sequence based on triangular numbers, which is guaranteed (since our
// table size is a power of two) to visit every group of elements exactly once.
//
// A triangular probe has us jump by 1 more group every time. So first we
// jump by 1 group (meaning we just continue our linear scan), then 2 groups
// (skipping over 1 group), then 3 groups (skipping over 2 groups), and so on.
//
// The proof is a short number-theory argument: i*(i+1)/2 walks through a complete
// residue system modulo the group count m, because m is a power of two.
// Equivalently, for 0 <= i <= j < m, i*(i+1)/2 == j*(j+1)/2 (mod m) iff i == j.
// If the two are congruent, then (j-i)(i+j+1) == 0 (mod 2m). Exactly one of the
// factors (j-i) and (i+j+1) is odd, so the even factor must be a multiple of 2m;
// since 0 <= j-i < m and 0 < i+j+1 < 2m, the only possibility is j-i == 0, i.e. i == j.
// The converse is trivial. Q.E.D.
[BurstCompile]
internal struct ProbeSeq
{
internal int pos;
private int _stride;
private int _bucket_mask;
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static ProbeSeq Create(int hash, int bucket_mask)
{
Unsafe.SkipInit(out ProbeSeq s);
s._bucket_mask = bucket_mask;
s.pos = h1(hash) & bucket_mask;
s._stride = 0;
return s;
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal void move_next()
{
// We should have found an empty bucket by now and ended the probe.
Debug.Assert(_stride <= _bucket_mask, "Went past end of probe sequence");
_stride += GROUP_WIDTH;
pos += _stride;
pos &= _bucket_mask;
}
}
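// Example (hypothetical helper, not used by the table itself): a tiny sanity check of the
// triangular-probe claim above, i.e. that the sequence 0, 1, 1+2, 1+2+3, ... taken modulo a
// power-of-two group count visits every group exactly once before repeating.
internal static class ProbeSeqCoverageExample
{
    internal static bool VisitsEveryGroupOnce(int groupCount)
    {
        // Only meaningful for a power-of-two group count (mirrors the table invariant).
        if (groupCount <= 0 || (groupCount & (groupCount - 1)) != 0)
            return false;
        var seen = new bool[groupCount];
        int pos = 0;
        int stride = 0;
        for (int i = 0; i < groupCount; i++)
        {
            if (seen[pos])
                return false; // a group was revisited before all groups were seen
            seen[pos] = true;
            stride += 1;                             // jump by one more group each step
            pos = (pos + stride) & (groupCount - 1); // wrap with the power-of-two mask
        }
        return true; // every group visited exactly once
    }
}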
static class SimdDetect
{
static bool _initialized;
static int _width; // 32=AVX2, 16=SSE2, 0=fallback
[BurstCompile(CompileSynchronously = true)]
struct ProbeJob : IJob
{
public NativeArray<int> Out;
public void Execute()
{
int w = Avx2.IsAvx2Supported ? 32 : (Sse2.IsSse2Supported ? 16 : 0);
Out[0] = w;
}
}
public static int GetWidth()
{
if (_initialized) return _width;
var tmp = new NativeArray<int>(1, Allocator.TempJob);
new ProbeJob { Out = tmp }.Run(); // or Schedule().Complete()
_width = tmp[0];
tmp.Dispose();
_initialized = true;
return _width;
}
}
internal struct SwissTableHelper
{
private static readonly int kSimdWidth = SimdDetect.GetWidth();
public static readonly int GROUP_WIDTH = (kSimdWidth != 0) ? kSimdWidth : FallbackGroup.WIDTH;
// Optional helpers to make call sites clear
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool UseAvx2() => kSimdWidth == Avx2Group.WIDTH;
[MethodImpl(MethodImplOptions.AggressiveInlining)]
internal static bool UseSse2() => kSimdWidth == Sse2Group.WIDTH;
// --- Managed side: compile, invoke, log/verify ---
#if UNITY_EDITOR
[UnityEditor.InitializeOnLoadMethod] // runs after editor domain reloads
#endif
[RuntimeInitializeOnLoadMethod(RuntimeInitializeLoadType.AfterAssembliesLoaded)] // runs on Play in Player/Editor
private static void InitSimdDetector()
{
Debug.Log($"SIMD width: {kSimdWidth}");
}
/// Control byte value for an empty bucket.
public const byte EMPTY = 0b1111_1111;
/// Control byte value for a deleted bucket.
public const byte DELETED = 0b1000_0000;
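// A FULL bucket stores h2(hash) in its control byte, which always has the top bit clear
// (values 0x00..0x7F). So the top bit alone separates FULL from EMPTY/DELETED, and bit 0
// separates EMPTY (1) from DELETED (0) among the special values.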
/// Checks whether a control byte represents a full bucket (top bit is clear).
public static bool is_full(byte ctrl) => (ctrl & 0x80) == 0;
/// Checks whether a control byte represents a special value (top bit is set).
public static bool is_special(byte ctrl) => (ctrl & 0x80) != 0;
/// Checks whether a special control value is EMPTY (just check 1 bit).
public static bool special_is_empty(byte ctrl)
{
Debug.Assert(is_special(ctrl));
return (ctrl & 0x01) != 0;
}
/// Checks whether a special control value is EMPTY.
// optimise: return 1 as true, 0 as false
public static int special_is_empty_with_int_return(byte ctrl)
{
Debug.Assert(is_special(ctrl));
return ctrl & 0x01;
}
/// Primary hash function, used to select the initial bucket to probe from.
public static int h1(int hash)
{
return hash;
}
/// Secondary hash function, saved in the low 7 bits of the control byte.
public static byte h2(int hash)
{
// Grab the top 7 bits of the hash.
// cast to uint to use `shr` rather than `sar`, which makes sure the top bit of the returned byte is 0.
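// Example: hash == unchecked((int)0xFE00_0000) has top 7 bits 0b111_1111, so h2 == 127.
// The result is always in the range 0..127, so a FULL control byte can never collide
// with EMPTY (0xFF) or DELETED (0x80).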
var top7 = (uint)hash >> 25;
return (byte)top7;
}
// DISPATCH METHODS
// Generally we do not want to duplicate code, but for performance (struct-based groups and inlining) we have to.
// The mirror implementations should differ only in the group type used (`_dummyGroup`), except `MoveNext`,
// which uses a C++-style union trick (`BitMaskUnion`) because the enumerator needs to record its current state.
[StructLayout(LayoutKind.Explicit)]
internal struct BitMaskUnion
{
[FieldOffset(0)]
internal Avx2BitMask avx2BitMask;
[FieldOffset(0)]
internal Sse2BitMask sse2BitMask;
[FieldOffset(0)]
internal FallbackBitMask fallbackBitMask;
}
// maybe we should just pass bucket_mask in as a parameter rather than calculate it here
private static int GetBucketMaskFromControlsLength(int controlsLength)
{
Debug.Assert(controlsLength >= GROUP_WIDTH);
if (controlsLength == GROUP_WIDTH)
return 0;
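// Example: with 16 buckets and GROUP_WIDTH == 16, controls.Length == 16 + 16 == 32,
// so the mask is 32 - 16 - 1 == 15 (i.e. buckets - 1).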
return controlsLength - GROUP_WIDTH - 1;
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static byte[] DispatchGetEmptyControls()
{
if (UseAvx2())
{
return Avx2Group.StaticEmpty;
}
if (UseSse2())
{
return Sse2Group.static_empty;
}
return FallbackGroup.static_empty;
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static BitMaskUnion DispatchGetMatchFullBitMask(byte[] controls, int index)
{
if (UseAvx2())
{
return GetMatchFullBitMaskForAvx2(controls, index);
}
if (UseSse2())
{
return GetMatchFullBitMaskForSse2(controls, index);
}
return GetMatchFullBitMaskForFallback(controls, index);
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe BitMaskUnion GetMatchFullBitMaskForAvx2(byte[] controls, int index)
{
BitMaskUnion result = default;
fixed (byte* ctrl = &controls[index])
{
result.avx2BitMask = Avx2Group.Load(ctrl).MatchFull();
}
return result;
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe BitMaskUnion GetMatchFullBitMaskForSse2(byte[] controls, int index)
{
BitMaskUnion result = default;
fixed (byte* ctrl = &controls[index])
{
result.sse2BitMask = Sse2Group.load(ctrl).MatchFull();
}
return result;
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe BitMaskUnion GetMatchFullBitMaskForFallback(byte[] controls, int index)
{
BitMaskUnion result = default;
fixed (byte* ctrl = &controls[index])
{
result.fallbackBitMask = FallbackGroup.load(ctrl).MatchFull();
}
return result;
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref SwissTable<TKey, TValue>.Entry DispatchMoveNextDictionary<TKey, TValue>(
int version,
int tolerantVersion,
in SwissTable<TKey, TValue> dictionary,
ref int currentCtrlOffset,
ref BitMaskUnion currentBitMask
)
where TKey : notnull
{
if (UseAvx2())
{
return ref MoveNextDictionaryForAvx2(version, tolerantVersion, in dictionary, ref currentCtrlOffset, ref currentBitMask);
}
if (UseSse2())
{
return ref MoveNextDictionaryForSse2(version, tolerantVersion, in dictionary, ref currentCtrlOffset, ref currentBitMask);
}
return ref MoveNextDictionaryForFallback(version, tolerantVersion, in dictionary, ref currentCtrlOffset, ref currentBitMask);
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref SwissTable<TKey, TValue>.Entry MoveNextDictionaryForAvx2<TKey, TValue>(
int version,
int tolerantVersion,
in SwissTable<TKey, TValue> dictionary,
ref int currentCtrlOffset,
ref BitMaskUnion currentBitMask
)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
ref var realBitMask = ref currentBitMask.avx2BitMask;
if (version != dictionary._version)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion();
}
if (tolerantVersion != dictionary._tolerantVersion)
{
var newBitMask = GetMatchFullBitMaskForAvx2(controls, currentCtrlOffset).avx2BitMask;
realBitMask = realBitMask.And(newBitMask);
}
while (true)
{
var lowest_set_bit = realBitMask.LowestSetBit();
if (lowest_set_bit >= 0)
{
Debug.Assert(entries != null);
realBitMask = realBitMask.RemoveLowestBit();
ref var entry = ref entries[currentCtrlOffset + lowest_set_bit];
return ref entry;
}
currentCtrlOffset += GROUP_WIDTH;
if (currentCtrlOffset >= dictionary._buckets)
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
realBitMask = GetMatchFullBitMaskForAvx2(controls, currentCtrlOffset).avx2BitMask;
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref SwissTable<TKey, TValue>.Entry MoveNextDictionaryForSse2<TKey, TValue>(
int version,
int tolerantVersion,
in SwissTable<TKey, TValue> dictionary,
ref int currentCtrlOffset,
ref BitMaskUnion currentBitMask
)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
ref var realBitMask = ref currentBitMask.sse2BitMask;
if (version != dictionary._version)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion();
}
if (tolerantVersion != dictionary._tolerantVersion)
{
var newBitMask = GetMatchFullBitMaskForSse2(controls, currentCtrlOffset).sse2BitMask;
realBitMask = realBitMask.And(newBitMask);
}
while (true)
{
var lowest_set_bit = realBitMask.LowestSetBit();
if (lowest_set_bit >= 0)
{
Debug.Assert(entries != null);
realBitMask = realBitMask.RemoveLowestBit();
ref var entry = ref entries[currentCtrlOffset + lowest_set_bit];
return ref entry;
}
currentCtrlOffset += GROUP_WIDTH;
if (currentCtrlOffset >= dictionary._buckets)
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
realBitMask = GetMatchFullBitMaskForSse2(controls, currentCtrlOffset).sse2BitMask;
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref SwissTable<TKey, TValue>.Entry MoveNextDictionaryForFallback<TKey, TValue>(
int version,
int tolerantVersion,
in SwissTable<TKey, TValue> dictionary,
ref int currentCtrlOffset,
ref BitMaskUnion currentBitMask
)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
ref var realBitMask = ref currentBitMask.fallbackBitMask;
if (version != dictionary._version)
{
// ThrowHelper.ThrowInvalidOperationException_InvalidOperation_EnumFailedVersion();
}
if (tolerantVersion != dictionary._tolerantVersion)
{
var newBitMask = GetMatchFullBitMaskForFallback(controls, currentCtrlOffset);
realBitMask = realBitMask.And(newBitMask.fallbackBitMask);
}
while (true)
{
var lowest_set_bit = realBitMask.LowestSetBit();
if (lowest_set_bit >= 0)
{
Debug.Assert(entries != null);
realBitMask = realBitMask.RemoveLowestBit();
ref var entry = ref entries[currentCtrlOffset + lowest_set_bit];
return ref entry;
}
currentCtrlOffset += GROUP_WIDTH;
if (currentCtrlOffset >= dictionary._buckets)
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
realBitMask = GetMatchFullBitMaskForFallback(controls, currentCtrlOffset).fallbackBitMask;
}
}
// If we are inside a continuous block of Group::WIDTH full or deleted
// cells then a probe window may have seen a full block when trying to
// insert. We therefore need to keep that block non-empty so that
// lookups will continue searching to the next probe window.
//
// Note that in this context `leading_zeros` refers to the bytes at the
// end of a group, while `trailing_zeros` refers to the bytes at the
// beginning of a group.
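// Concretely: `empty_before.LeadingZeros()` counts the non-EMPTY bytes immediately before
// `index`, and `empty_after.TrailingZeros()` counts those starting at `index`. If the combined
// run is shorter than GROUP_WIDTH, no probe window could have seen a fully non-EMPTY group
// across this slot, so writing EMPTY (instead of a DELETED tombstone) is safe.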
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static bool DispatchIsEraseSafeToSetEmptyControlFlag(int bucketMask, byte[] controls, int index)
{
if (UseAvx2())
{
return IsEraseSafeToSetEmptyControlFlagForAvx2(bucketMask, controls, index);
}
if (UseSse2())
{
return IsEraseSafeToSetEmptyControlFlagForSse2(bucketMask, controls, index);
}
return IsEraseSafeToSetEmptyControlFlagForFallback(bucketMask, controls, index);
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe bool IsEraseSafeToSetEmptyControlFlagForAvx2(int bucketMask, byte[] controls, int index)
{
Debug.Assert(bucketMask == GetBucketMaskFromControlsLength(controls.Length));
int indexBefore = unchecked((index - GROUP_WIDTH) & bucketMask);
fixed (byte* ptr_before = &controls[indexBefore])
fixed (byte* ptr = &controls[index])
{
var empty_before = Avx2Group.Load(ptr_before).MatchEmpty();
var empty_after = Avx2Group.Load(ptr).MatchEmpty();
return empty_before.LeadingZeros() + empty_after.TrailingZeros() < GROUP_WIDTH;
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe bool IsEraseSafeToSetEmptyControlFlagForSse2(int bucketMask, byte[] controls, int index)
{
Debug.Assert(bucketMask == GetBucketMaskFromControlsLength(controls.Length));
int indexBefore = unchecked((index - GROUP_WIDTH) & bucketMask);
fixed (byte* ptr_before = &controls[indexBefore])
fixed (byte* ptr = &controls[index])
{
var empty_before = Sse2Group.load(ptr_before).MatchEmpty();
var empty_after = Sse2Group.load(ptr).MatchEmpty();
return empty_before.LeadingZeros() + empty_after.TrailingZeros() < GROUP_WIDTH;
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe bool IsEraseSafeToSetEmptyControlFlagForFallback(int bucketMask, byte[] controls, int index)
{
Debug.Assert(bucketMask == GetBucketMaskFromControlsLength(controls.Length));
int indexBefore = unchecked((index - GROUP_WIDTH) & bucketMask);
fixed (byte* ptr_before = &controls[indexBefore])
fixed (byte* ptr = &controls[index])
{
var empty_before = FallbackGroup.load(ptr_before).MatchEmpty();
var empty_after = FallbackGroup.load(ptr).MatchEmpty();
return empty_before.LeadingZeros() + empty_after.TrailingZeros() < GROUP_WIDTH;
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static ref SwissTable<TKey, TValue>.Entry DispatchFindBucketOfDictionary<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key, int hashOfKey)
where TKey : notnull
{
if (UseAvx2())
{
return ref FindBucketOfDictionaryForAvx2(dictionary, key, hashOfKey);
}
if (UseSse2())
{
return ref FindBucketOfDictionaryForSse2(dictionary, key, hashOfKey);
}
return ref FindBucketOfDictionaryForFallback(dictionary, key, hashOfKey);
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe ref SwissTable<TKey, TValue>.Entry FindBucketOfDictionaryForAvx2<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key, int hash)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var bucketMask = dictionary.rawTable._bucket_mask;
var hashComparer = dictionary._comparer;
Debug.Assert(controls != null);
var h2_hash = h2(hash);
var targetGroup = Avx2Group.Create(h2_hash);
var probeSeq = ProbeSeq.Create(hash, bucketMask);
if (hashComparer == null)
{
if (typeof(TKey).IsValueType)
{
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Avx2Group.Load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (EqualityComparer<TKey>.Default.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Avx2Group.Load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (defaultComparer.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Avx2Group.Load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (hashComparer.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe ref SwissTable<TKey, TValue>.Entry FindBucketOfDictionaryForSse2<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key, int hash)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var bucketMask = dictionary.rawTable._bucket_mask;
var hashComparer = dictionary._comparer;
Debug.Assert(controls != null);
var h2_hash = h2(hash);
var targetGroup = Sse2Group.Create(h2_hash);
var probeSeq = ProbeSeq.Create(hash, bucketMask);
if (hashComparer == null)
{
if (typeof(TKey).IsValueType)
{
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Sse2Group.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (EqualityComparer<TKey>.Default.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Sse2Group.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (defaultComparer.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Sse2Group.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (hashComparer.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe ref SwissTable<TKey, TValue>.Entry FindBucketOfDictionaryForFallback<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key, int hash)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var bucketMask = dictionary.rawTable._bucket_mask;
var hashComparer = dictionary._comparer;
Debug.Assert(controls != null);
var h2_hash = h2(hash);
var targetGroup = FallbackGroup.create(h2_hash);
var probeSeq = ProbeSeq.Create(hash, bucketMask);
if (hashComparer == null)
{
if (typeof(TKey).IsValueType)
{
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = FallbackGroup.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (EqualityComparer<TKey>.Default.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = FallbackGroup.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (defaultComparer.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = FallbackGroup.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (hashComparer.Equals(key, entry.Key))
{
return ref entry;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return ref Unsafe.NullRef<SwissTable<TKey, TValue>.Entry>();
}
probeSeq.move_next();
}
}
}
/// <summary>
/// Find the index of given key, negative means not found.
/// </summary>
/// <typeparam name="TKey"></typeparam>
/// <typeparam name="TValue"></typeparam>
/// <param name="dictionary"></param>
/// <param name="key"></param>
/// <returns>
/// negative return value means not found
/// </returns>
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int DispatchFindBucketIndexOfDictionary<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key)
where TKey : notnull
{
if (UseAvx2())
{
return FindBucketIndexOfDictionaryForAvx2(dictionary, key);
}
if (UseSse2())
{
return FindBucketIndexOfDictionaryForSse2(dictionary, key);
}
return FindBucketIndexOfDictionaryForFallback(dictionary, key);
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int FindBucketIndexOfDictionaryForAvx2<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var bucketMask = dictionary.rawTable._bucket_mask;
var hashComparer = dictionary._comparer;
Debug.Assert(controls != null);
var hash = hashComparer == null ? key.GetHashCode() : hashComparer.GetHashCode(key);
var h2_hash = h2(hash);
var targetGroup = Avx2Group.Create(h2_hash);
var probeSeq = ProbeSeq.Create(hash, bucketMask);
if (hashComparer == null)
{
if (typeof(TKey).IsValueType)
{
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Avx2Group.Load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (EqualityComparer<TKey>.Default.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Avx2Group.Load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (defaultComparer.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Avx2Group.Load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (hashComparer.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int FindBucketIndexOfDictionaryForSse2<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var bucketMask = dictionary.rawTable._bucket_mask;
var hashComparer = dictionary._comparer;
Debug.Assert(controls != null);
var hash = hashComparer == null ? key.GetHashCode() : hashComparer.GetHashCode(key);
var h2_hash = h2(hash);
var targetGroup = Sse2Group.Create(h2_hash);
var probeSeq = ProbeSeq.Create(hash, bucketMask);
if (hashComparer == null)
{
if (typeof(TKey).IsValueType)
{
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Sse2Group.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (EqualityComparer<TKey>.Default.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Sse2Group.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (defaultComparer.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = Sse2Group.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (hashComparer.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int FindBucketIndexOfDictionaryForFallback<TKey, TValue>(SwissTable<TKey, TValue> dictionary, TKey key)
where TKey : notnull
{
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var bucketMask = dictionary.rawTable._bucket_mask;
var hashComparer = dictionary._comparer;
Debug.Assert(controls != null);
var hash = hashComparer == null ? key.GetHashCode() : hashComparer.GetHashCode(key);
var h2_hash = h2(hash);
var targetGroup = FallbackGroup.create(h2_hash);
var probeSeq = ProbeSeq.Create(hash, bucketMask);
if (hashComparer == null)
{
if (typeof(TKey).IsValueType)
{
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = FallbackGroup.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (EqualityComparer<TKey>.Default.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
EqualityComparer<TKey> defaultComparer = EqualityComparer<TKey>.Default;
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = FallbackGroup.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (defaultComparer.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
fixed (byte* ptr = &controls[0])
{
while (true)
{
var group = FallbackGroup.load(ptr + probeSeq.pos);
var bitmask = group.MatchGroup(targetGroup);
// TODO: an iterator would be clearer here; use one if it does not hurt performance.
while (bitmask.AnyBitSet())
{
// there must be a set bit
Debug.Assert(entries != null);
var bit = bitmask.LowestSetBitNonzero();
bitmask = bitmask.RemoveLowestBit();
var index = (probeSeq.pos + bit) & bucketMask;
ref var entry = ref entries[index];
if (hashComparer.Equals(key, entry.Key))
{
return index;
}
}
if (group.MatchEmpty().AnyBitSet())
{
return -1;
}
probeSeq.move_next();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static void DispatchCopyToArrayFromDictionaryWorker<TKey, TValue>(SwissTable<TKey, TValue> dictionary, KeyValuePair<TKey, TValue>[] destArray, int index)
where TKey : notnull
{
if (UseAvx2())
{
CopyToArrayFromDictionaryWorkerForAvx2(dictionary, destArray, index);
}
else
if (UseSse2())
{
CopyToArrayFromDictionaryWorkerForSse2(dictionary, destArray, index);
}
else
{
CopyToArrayFromDictionaryWorkerForFallback(dictionary, destArray, index);
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void CopyToArrayFromDictionaryWorkerForAvx2<TKey, TValue>(SwissTable<TKey, TValue> dictionary, KeyValuePair<TKey, TValue>[] destArray, int index)
where TKey : notnull
{
int offset = 0;
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var buckets = entries?.Length ?? 0;
Debug.Assert(controls != null);
fixed (byte* ptr = &controls[0])
{
var bitMask = Avx2Group.Load(ptr).MatchFull();
while (true)
{
var lowestSetBit = bitMask.LowestSetBit();
if (lowestSetBit >= 0)
{
Debug.Assert(entries != null);
bitMask = bitMask.RemoveLowestBit();
ref var entry = ref entries[offset + lowestSetBit];
destArray[index++] = new KeyValuePair<TKey, TValue>(entry.Key, entry.Value);
continue;
}
offset += GROUP_WIDTH;
if (offset >= buckets)
{
break;
}
bitMask = Avx2Group.Load(ptr + offset).MatchFull();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void CopyToArrayFromDictionaryWorkerForSse2<TKey, TValue>(SwissTable<TKey, TValue> dictionary, KeyValuePair<TKey, TValue>[] destArray, int index)
where TKey : notnull
{
int offset = 0;
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var buckets = entries?.Length ?? 0;
Debug.Assert(controls != null);
fixed (byte* ptr = &controls[0])
{
var bitMask = Sse2Group.load(ptr).MatchFull();
while (true)
{
var lowestSetBit = bitMask.LowestSetBit();
if (lowestSetBit >= 0)
{
Debug.Assert(entries != null);
bitMask = bitMask.RemoveLowestBit();
ref var entry = ref entries[offset + lowestSetBit];
destArray[index++] = new KeyValuePair<TKey, TValue>(entry.Key, entry.Value);
continue;
}
offset += GROUP_WIDTH;
if (offset >= buckets)
{
break;
}
bitMask = Sse2Group.load(ptr + offset).MatchFull();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe void CopyToArrayFromDictionaryWorkerForFallback<TKey, TValue>(SwissTable<TKey, TValue> dictionary, KeyValuePair<TKey, TValue>[] destArray, int index)
where TKey : notnull
{
int offset = 0;
var controls = dictionary.rawTable._controls;
var entries = dictionary.rawTable._entries;
var buckets = entries?.Length ?? 0;
Debug.Assert(controls != null);
fixed (byte* ptr = &controls[0])
{
var bitMask = FallbackGroup.load(ptr).MatchFull();
while (true)
{
var lowestSetBit = bitMask.LowestSetBit();
if (lowestSetBit >= 0)
{
Debug.Assert(entries != null);
bitMask = bitMask.RemoveLowestBit();
ref var entry = ref entries[offset + lowestSetBit];
destArray[index++] = new KeyValuePair<TKey, TValue>(entry.Key, entry.Value);
continue;
}
offset += GROUP_WIDTH;
if (offset >= buckets)
{
break;
}
bitMask = FallbackGroup.load(ptr + offset).MatchFull();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
public static int DispatchFindInsertSlot(int hash, byte[] controls, int bucketMask)
{
if (UseAvx2())
{
return FindInsertSlotForAvx2(hash, controls, bucketMask);
}
if (UseSse2())
{
return FindInsertSlotForSse2(hash, controls, bucketMask);
}
return FindInsertSlotForFallback(hash, controls, bucketMask);
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int FindInsertSlotForAvx2(int hash, byte[] controls, int bucketMask)
{
Debug.Assert(bucketMask == GetBucketMaskFromControlsLength(controls.Length));
ProbeSeq probeSeq = ProbeSeq.Create(hash, bucketMask);
fixed (byte* ptr = &controls[0])
{
while (true)
{
// TODO: maybe we should lock or even `fixed` the whole loop;
// I am not sure which would be faster.
var bit = Avx2Group.Load(ptr + probeSeq.pos)
.MatchEmptyOrDeleted()
.LowestSetBit();
if (bit >= 0)
{
var result = (probeSeq.pos + bit) & bucketMask;
// In tables smaller than the group width, trailing control
// bytes outside the range of the table are filled with
// EMPTY entries. These will unfortunately trigger a
// match, but once masked may point to a full bucket that
// is already occupied. We detect this situation here and
// perform a second scan starting at the beginning of the
// table. This second scan is guaranteed to find an empty
// slot (due to the load factor) before hitting the trailing
// control bytes (containing EMPTY).
if (!is_full(*(ptr + result)))
{
return result;
}
Debug.Assert(bucketMask < GROUP_WIDTH);
Debug.Assert(probeSeq.pos != 0);
return Avx2Group.Load(ptr)
.MatchEmptyOrDeleted()
.LowestSetBitNonzero();
}
probeSeq.move_next();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int FindInsertSlotForSse2(int hash, byte[] controls, int bucketMask)
{
Debug.Assert(bucketMask == GetBucketMaskFromControlsLength(controls.Length));
ProbeSeq probeSeq = ProbeSeq.Create(hash, bucketMask);
fixed (byte* ptr = &controls[0])
{
while (true)
{
// TODO: maybe we should lock or even `fixed` the whole loop;
// I am not sure which would be faster.
var bit = Sse2Group.load(ptr + probeSeq.pos)
.MatchEmptyOrDeleted()
.LowestSetBit();
if (bit >= 0)
{
var result = (probeSeq.pos + bit) & bucketMask;
// In tables smaller than the group width, trailing control
// bytes outside the range of the table are filled with
// EMPTY entries. These will unfortunately trigger a
// match, but once masked may point to a full bucket that
// is already occupied. We detect this situation here and
// perform a second scan starting at the beginning of the
// table. This second scan is guaranteed to find an empty
// slot (due to the load factor) before hitting the trailing
// control bytes (containing EMPTY).
if (!is_full(*(ptr + result)))
{
return result;
}
Debug.Assert(bucketMask < GROUP_WIDTH);
Debug.Assert(probeSeq.pos != 0);
return Sse2Group.load(ptr)
.MatchEmptyOrDeleted()
.LowestSetBitNonzero();
}
probeSeq.move_next();
}
}
}
[SkipLocalsInit]
[MethodImpl(MethodImplOptions.AggressiveInlining)]
private static unsafe int FindInsertSlotForFallback(int hash, byte[] controls, int bucketMask)
{
Debug.Assert(bucketMask == GetBucketMaskFromControlsLength(controls.Length));
ProbeSeq probeSeq = ProbeSeq.Create(hash, bucketMask);
fixed (byte* ptr = &controls[0])
{
while (true)
{
// TODO: maybe we should lock or even `fixed` the whole loop;
// I am not sure which would be faster.
var bit = FallbackGroup.load(ptr + probeSeq.pos)
.MatchEmptyOrDeleted()
.LowestSetBit();
if (bit >= 0)
{
var result = (probeSeq.pos + bit) & bucketMask;
// In tables smaller than the group width, trailing control
// bytes outside the range of the table are filled with
// EMPTY entries. These will unfortunately trigger a
// match, but once masked may point to a full bucket that
// is already occupied. We detect this situation here and
// perform a second scan starting at the beginning of the
// table. This second scan is guaranteed to find an empty
// slot (due to the load factor) before hitting the trailing
// control bytes (containing EMPTY).
if (!is_full(*(ptr + result)))
{
return result;
}
Debug.Assert(bucketMask < GROUP_WIDTH);
Debug.Assert(probeSeq.pos != 0);
return FallbackGroup.load(ptr)
.MatchEmptyOrDeleted()
.LowestSetBitNonzero();
}
probeSeq.move_next();
}
}
}
}
/// <summary>
/// Used internally to control behavior of insertion into a <see cref="Dictionary{TKey, TValue}"/> or <see cref="HashSet{T}"/>.
/// </summary>
internal enum InsertionBehavior : byte
{
/// <summary>
/// The default insertion behavior.
/// </summary>
None = 0,
/// <summary>
/// Specifies that an existing entry with the same key should be overwritten if encountered.
/// </summary>
OverwriteExisting = 1,
/// <summary>
/// Specifies that if an existing entry with the same key is encountered, an exception should be thrown.
/// </summary>
ThrowOnExisting = 2
}
}
@SolidAlloy (Author):
A quick port of https://github.com/ShuiRuTian/SwissTable to Unity and Burst for performance testing purposes. I tested with a capacity of 1024, as that was my use case. Performance compared against Dictionary and a much simpler Robin Hood hash table with backshift deletion (relative timings, Dictionary = 1x):

| Operation | Dictionary | Robin Hood | Swiss Table |
| --- | --- | --- | --- |
| Add | 1x | 0.65x | 6.6x |
| Get existing | 1x | 0.5x | 1.5x |
| Get missing | 1x | 0.65x | 10.5x |
| Resize | 1x | 1x | 1.5x |

I don't know why it's so bad. Perhaps I made some dumb mistake when porting.
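For context, this is roughly the shape of the measurement loop (a minimal sketch, not the exact benchmark behind the numbers above; the `SwissTable<int, int>(capacity)` constructor and the key setup are assumptions):

```csharp
using System.Collections.Generic;
using System.Diagnostics;
using DefaultNamespace; // where SwissTable<TKey, TValue> lives in this gist

public static class HashMapBenchSketch
{
    const int Capacity = 1024;
    const int Iterations = 1000;

    public static void Run()
    {
        // Distinct keys in a shuffled order, so Add never throws on duplicates.
        var keys = new int[Capacity];
        for (int i = 0; i < Capacity; i++) keys[i] = i;
        var rng = new System.Random(12345);
        for (int i = Capacity - 1; i > 0; i--)
        {
            int j = rng.Next(i + 1);
            int tmp = keys[i]; keys[i] = keys[j]; keys[j] = tmp;
        }

        // Time Add into a pre-sized SwissTable, repeated for a stable number.
        var sw = Stopwatch.StartNew();
        for (int it = 0; it < Iterations; it++)
        {
            var map = new SwissTable<int, int>(Capacity); // assumes a Dictionary-style capacity ctor
            for (int i = 0; i < Capacity; i++) map.Add(keys[i], i);
        }
        sw.Stop();
        UnityEngine.Debug.Log($"SwissTable Add: {sw.Elapsed.TotalMilliseconds:F2} ms");

        // The same loop with Dictionary<int, int> gives the 1x baseline;
        // TryGetValue over existing/missing keys covers the Get rows.
        var baseline = Stopwatch.StartNew();
        for (int it = 0; it < Iterations; it++)
        {
            var map = new Dictionary<int, int>(Capacity);
            for (int i = 0; i < Capacity; i++) map.Add(keys[i], i);
        }
        baseline.Stop();
        UnityEngine.Debug.Log($"Dictionary Add: {baseline.Elapsed.TotalMilliseconds:F2} ms");
    }
}
```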
