#include "set.h"
#include "crash.h"
#include <stdlib.h>

#define DEFAULT_CAPACITY 16
/* Marker stored in a bucket whose key was removed; keeps probe chains intact. */
#define TOMBSTONE_VAL ((void*)-1)

/*
 * Open-addressing hash set of pointer keys using quadratic probing.
 * A bucket holds NULL (never used), TOMBSTONE_VAL (removed), or a live key.
 * NULL and TOMBSTONE_VAL can therefore not be stored as keys.
 */

/* Returns true when a bucket value is a real key (neither empty nor removed). */
static bool is_live_key(const void* bucket)
{
    return bucket != NULL && bucket != TOMBSTONE_VAL;
}

/*
 * Initializes `set` with the default capacity (DEFAULT_CAPACITY keys).
 * See set_init_capacity for the meaning of the other parameters.
 */
void set_init(
    set_t* set, hash_func_t hash_func, eq_func_t eq_func, double load_limit
) {
    set_init_capacity(set, hash_func, eq_func, load_limit, DEFAULT_CAPACITY);
}

/*
 * Initializes `set` so that `capacity` keys fit without exceeding
 * `load_limit` (ratio of stored keys to buckets that triggers a rehash).
 *
 * Crashes if `load_limit` is not a positive number or if the bucket array
 * cannot be allocated.
 */
void set_init_capacity(
    set_t* set, hash_func_t hash_func, eq_func_t eq_func,
    double load_limit, size_t capacity
) {
    /* Written as a negated comparison so that NaN is rejected as well;
     * a zero/negative/NaN limit would make the bucket count conversion
     * below undefined. */
    if (!(load_limit > 0.0))
        crash("Set load limit must be positive: %f\n", load_limit);

    set->hash_func = hash_func;
    set->eq_func = eq_func;
    set->load_limit = load_limit;
    set->size = 0;
    /* +1 guarantees at least one bucket, so the modulo in probing is safe. */
    set->__num_buckets = capacity / load_limit + 1;
    set->__buckets = calloc(set->__num_buckets, sizeof(void*));
    if (set->__buckets == NULL)
        crash("Out of memory allocating %zu hash buckets\n",
              set->__num_buckets);
}

/*
 * Releases the bucket array owned by `set`. The keys themselves and the
 * set_t object are not freed here.
 */
void set_destroy(set_t* set)
{
    free(set->__buckets);
    set->__buckets = NULL;
}

/*
 * Probes for `key` and returns a bucket index, or `set->__num_buckets` when
 * the probe sequence is exhausted without a usable bucket.
 *
 * Lookup mode (`for_insert == false`): returns the bucket holding a key equal
 * to `key`, or the first NULL bucket reached (meaning "absent"). Tombstones
 * are skipped.
 *
 * Insert mode (`for_insert == true`): returns the bucket holding an equal key
 * if one exists anywhere on the probe chain; otherwise the first tombstone
 * seen on the chain, or failing that the first NULL bucket. Probing past
 * tombstones before reusing one prevents inserting a duplicate of a key that
 * sits further along the chain.
 *
 * Crashes if `key` is NULL or TOMBSTONE_VAL.
 */
static size_t fetch_set_idx(
    const set_t* set, const void* key, bool for_insert
) {
    if (key == NULL || key == TOMBSTONE_VAL)
        crash("safec sets do not support non-address keys: %p.\n", key);

    const size_t num_buckets = set->__num_buckets;
    const size_t base = set->hash_func(key) % num_buckets;
    /* num_buckets doubles as "no tombstone remembered". */
    size_t first_tombstone = num_buckets;
    size_t idx = base, i = 0, offset = 0;

    for (;;) {
        void* cur = set->__buckets[idx];
        if (cur == NULL) {
            /* End of chain: key is absent. */
            return first_tombstone < num_buckets ? first_tombstone : idx;
        }
        if (cur == TOMBSTONE_VAL) {
            if (for_insert && first_tombstone == num_buckets)
                first_tombstone = idx;
        } else if (set->eq_func(key, cur)) {
            return idx;
        }

        /* Give up once the quadratic offset reaches the table size. */
        if (offset >= num_buckets)
            return first_tombstone;

        i++;
        offset = i * i;
        idx = (base + offset) % num_buckets;
    }
}

/* Returns true if a key equal to `key` is stored in `set`. */
bool set_contains(const set_t* set, const void* key)
{
    size_t idx = fetch_set_idx(set, key, false);
    return idx < set->__num_buckets && set->__buckets[idx] != NULL;
}

/* Returns the stored key equal to `key`, or NULL if there is none. */
void* set_get(const set_t* set, const void* key)
{
    return set_get_or_default(set, key, NULL);
}

/* Returns the stored key equal to `key`, or `default_key` if there is none. */
void* set_get_or_default(const set_t* set, const void* key, void* default_key)
{
    size_t idx = fetch_set_idx(set, key, false);
    if (idx >= set->__num_buckets)
        return default_key;

    void* found = set->__buckets[idx];
    return found == NULL ? default_key : found;
}

/*
 * Doubles the bucket count and reinserts every live key. Tombstones are
 * dropped rather than carried over. Crashes on allocation failure or if a
 * key cannot be placed in the larger table.
 */
static void rehash_set(set_t* set)
{
    size_t old_num_buckets = set->__num_buckets;
    void** old_buckets = set->__buckets;

    set->__num_buckets <<= 1;
    set->__buckets = calloc(set->__num_buckets, sizeof(void*));
    if (set->__buckets == NULL)
        crash("Out of memory allocating %zu hash buckets\n",
              set->__num_buckets);

    for (size_t i = 0; i < old_num_buckets; i++) {
        void* key = old_buckets[i];
        /* A tombstone is not a key; passing it to the probe would crash. */
        if (!is_live_key(key))
            continue;

        size_t idx = fetch_set_idx(set, key, true);
        if (idx >= set->__num_buckets)
            crash("Set failed rehashing, likely due to bad hash function.\n");
        set->__buckets[idx] = key;
    }

    free(old_buckets);
}

/*
 * Finds the bucket where `key` lives or should be stored, growing the table
 * when needed.
 *
 * If an equal key is already stored, its bucket is returned and the size is
 * left untouched. Otherwise the size is incremented and a free bucket (NULL
 * or tombstone) is returned; the caller must store a key there.
 */
static size_t fetch_set_idx_rehashing(set_t* set, void* key)
{
    size_t idx = fetch_set_idx(set, key, true);
    if (idx < set->__num_buckets && is_live_key(set->__buckets[idx]))
        return idx;

    if ((double) (set->size + 1) / set->__num_buckets > set->load_limit
            || idx >= set->__num_buckets) {
        rehash_set(set);
        idx = fetch_set_idx(set, key, true);
        if (idx >= set->__num_buckets)
            crash("Set still full after rehashing, "
                  "likely due to bad hash function.\n");
    }

    set->size++;
    return idx;
}

/*
 * Adds `key` unless an equal key is already stored.
 * Returns the key that is in the set afterwards: the pre-existing equal key,
 * or `key` itself when it was newly inserted.
 */
void* set_add(set_t* set, void* key)
{
    size_t idx = fetch_set_idx_rehashing(set, key);
    void* old_key = set->__buckets[idx];
    if (is_live_key(old_key))
        return old_key;

    set->__buckets[idx] = key;
    return key;
}

/* Stores `key`, replacing any equal key that is already in the set. */
void set_put(set_t* set, void* key)
{
    size_t idx = fetch_set_idx_rehashing(set, key);
    set->__buckets[idx] = key;
}

/*
 * Removes the key equal to `key`. Returns the removed key, or NULL when no
 * equal key was stored (in which case the set is left unchanged).
 */
void* set_remove(set_t* set, const void* key)
{
    size_t idx = fetch_set_idx(set, key, false);
    if (idx >= set->__num_buckets)
        return NULL;

    void* old_key = set->__buckets[idx];
    if (old_key == NULL)
        return NULL;

    set->size--;
    /* Leave a tombstone so keys further along this probe chain stay
     * reachable. */
    set->__buckets[idx] = TOMBSTONE_VAL;
    return old_key;
}

/*
 * Calls `foreach_func(key, data)` for every live key, in bucket order.
 */
void set_foreach(set_t* set, set_foreach_func_t foreach_func, void* data)
{
    for (size_t i = 0; i < set->__num_buckets; i++) {
        void* key = set->__buckets[i];
        if (key == NULL || key == TOMBSTONE_VAL)
            continue;
        foreach_func(key, data);
    }
}