summaryrefslogtreecommitdiff
path: root/set.c
blob: e69cef006dfa85b69c0549ac82852f64de7df183 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
#include "set.h"
#include "crash.h"
#include <stdint.h>
#include <stdlib.h>

/* Starting capacity that set_init forwards to set_init_capacity. */
#define DEFAULT_CAPACITY 16
/* Sentinel written into a bucket by set_remove. Unlike an empty (NULL)
 * bucket, lookups keep probing past it, so keys that were placed after a
 * since-removed colliding key remain reachable. Because it is reserved as a
 * marker, fetch_set_idx refuses it (and NULL) as a key. */
#define TOMBSTONE_VAL ((void*)-1)

/*
 * Initialize `set` using DEFAULT_CAPACITY as the expected element count.
 * Thin convenience wrapper over set_init_capacity; see it for the meaning of
 * `load_limit` and for allocation-failure behavior.
 */
void set_init(
    set_t* set,
    hash_func_t hash_func,
    eq_func_t eq_func,
    double load_limit
) {
    const size_t initial_capacity = DEFAULT_CAPACITY;
    set_init_capacity(
        set, hash_func, eq_func, load_limit, initial_capacity);
}

/*
 * Initialize `set` so that `capacity` elements fit without exceeding
 * `load_limit` (elements / buckets). The bucket count is
 * capacity / load_limit + 1, so there is always at least one bucket.
 *
 * Crashes (via crash()) when:
 *   - load_limit is not strictly positive (this also rejects NaN);
 *   - the requested bucket count does not fit in a size_t;
 *   - the bucket array cannot be allocated.
 */
void set_init_capacity(
    set_t* set,
    hash_func_t hash_func,
    eq_func_t eq_func,
    double load_limit,
    size_t capacity
) {
    /* A zero, negative, or NaN limit would make the division below produce
     * inf/negative/NaN, and converting that to size_t is undefined. */
    if (!(load_limit > 0))
        crash("Invalid set load limit: %f\n", load_limit);

    double wanted_buckets = (double) capacity / load_limit + 1;
    /* Out-of-range double -> size_t conversion is undefined behavior. */
    if (wanted_buckets >= (double) SIZE_MAX)
        crash("Set capacity %zu too large for load limit %f\n",
              capacity, load_limit);

    set->hash_func = hash_func;
    set->eq_func = eq_func;
    set->load_limit = load_limit;
    set->size = 0;
    set->__num_buckets = (size_t) wanted_buckets;
    set->__buckets = calloc(set->__num_buckets, sizeof(void*));

    if (set->__buckets == NULL)
        crash("Out of memory allocating %zu hash buckets\n", set->__num_buckets);
}

/*
 * Release the bucket array owned by `set`. The stored keys themselves are
 * not freed; they belong to the caller. Afterwards the set holds no buckets
 * and no elements. Because free(NULL) is a no-op, a second set_destroy is
 * harmless. Re-initialize the set with set_init before using it again.
 */
void set_destroy(set_t* set) {
    free(set->__buckets);
    set->__buckets = NULL;   /* guard against double free / use-after-free */
    set->__num_buckets = 0;
    set->size = 0;
}

/*
 * Find the bucket index for `key` using quadratic probing:
 * base + i*i (mod num_buckets), stopping after the first probe whose
 * offset i*i reaches num_buckets.
 *
 * Returns:
 *   - the index of the bucket holding a key equal to `key` (eq_func), if one
 *     is found on the probe path;
 *   - otherwise, when `fetch_tombstones` is true, the first tombstone seen on
 *     the probe path (so inserts reuse it), else the empty (NULL) bucket where
 *     probing stopped;
 *   - num_buckets when the probe bound is exhausted with no match and no
 *     usable slot.
 *
 * The search continues past tombstones even when fetch_tombstones is set.
 * Stopping at the first tombstone would hide an equal key stored further
 * along, and an insert would then create a duplicate.
 *
 * Crashes on NULL or TOMBSTONE_VAL keys, since both are reserved markers.
 */
static size_t fetch_set_idx(
    const set_t* set,
    const void* key,
    bool fetch_tombstones
) {
    if (key == NULL || key == TOMBSTONE_VAL)
        crash("safec sets do not support non-address keys: %p.\n", key);

    const size_t num_buckets = set->__num_buckets;
    size_t base = set->hash_func(key) % num_buckets;
    size_t idx = base, i = 0, offset = 0;
    /* num_buckets doubles as "no tombstone seen yet". */
    size_t first_tombstone = num_buckets;

    for (;;) {
        void* cur = set->__buckets[idx];
        if (cur == NULL) {
            /* Key is absent; prefer recycling an earlier tombstone. */
            if (fetch_tombstones && first_tombstone < num_buckets)
                return first_tombstone;
            return idx;
        }
        if (cur == TOMBSTONE_VAL) {
            if (fetch_tombstones && first_tombstone == num_buckets)
                first_tombstone = idx;
        } else if (set->eq_func(key, cur)) {
            return idx;
        }
        /* Probe bound reached: fall back to a tombstone, or report "full". */
        if (offset >= num_buckets) return first_tombstone;
        i++;
        offset = i * i;
        idx = (base + offset) % num_buckets;
    }
}

/*
 * Report whether a key equal to `key` (per the set's eq_func) is stored.
 * A probe that runs past the table's bound counts as "not present".
 */
bool set_contains(const set_t* set, const void* key) {
    const size_t slot = fetch_set_idx(set, key, false);
    if (slot >= set->__num_buckets)
        return false;
    return set->__buckets[slot] != NULL;
}

/*
 * Return the stored key equal to `key`, or NULL when no such key exists.
 * Equivalent to set_get_or_default with a NULL default.
 */
void* set_get(const set_t* set, const void* key) {
    void* const not_found = NULL;
    return set_get_or_default(set, key, not_found);
}

/*
 * Return the stored pointer whose key compares equal to `key`. When no
 * match is found, including when probing exhausts its bound, return
 * `default_key` instead.
 */
void* set_get_or_default(const set_t* set, const void* key, void* default_key) {
    const size_t slot = fetch_set_idx(set, key, false);
    void* hit = (slot < set->__num_buckets) ? set->__buckets[slot] : NULL;
    if (hit != NULL)
        return hit;
    return default_key;
}

/*
 * Double the number of buckets and re-insert every live key. Tombstones are
 * dropped, so the new table holds only real keys. `size` is left unchanged.
 *
 * Crashes (via crash()) if:
 *   - the bucket count would overflow size_t;
 *   - the new array cannot be allocated;
 *   - a key finds no slot within the probe bound, which suggests a poor
 *     hash function.
 */
static void rehash_set(set_t* set) {
    size_t old_num_buckets = set->__num_buckets;
    void** old_buckets = set->__buckets;

    if (old_num_buckets > SIZE_MAX / 2)
        crash("Cannot grow set beyond %zu hash buckets\n", old_num_buckets);
    size_t new_num_buckets = old_num_buckets * 2;

    /* Allocate into a local first so `set` is not half-updated on failure. */
    void** new_buckets = calloc(new_num_buckets, sizeof(void*));
    if (new_buckets == NULL)
        crash("Out of memory allocating %zu hash buckets\n", new_num_buckets);

    set->__num_buckets = new_num_buckets;
    set->__buckets = new_buckets;

    for (size_t i = 0; i < old_num_buckets; i++) {
        void* key = old_buckets[i];
        /* Skip empty slots and tombstones. fetch_set_idx rejects
         * TOMBSTONE_VAL as a key, so copying it would crash. */
        if (key == NULL || key == TOMBSTONE_VAL) continue;
        size_t idx = fetch_set_idx(set, key, true);
        if (idx >= set->__num_buckets)
            crash(
                "Set failed rehashing, likely due to bad hash function.\n");
        set->__buckets[idx] = key;
    }
    free(old_buckets);
}

/*
 * Return the bucket index where `key` should be written by an insert.
 *
 * If a key equal to `key` is already stored, its index is returned and
 * `size` is left unchanged. Otherwise `size` is incremented for the new
 * element. The table is grown when:
 *   - size / num_buckets would exceed load_limit; or
 *   - no slot was found within the probe bound.
 *
 * Crashes if no slot can be found even after growing.
 */
static size_t fetch_set_idx_rehashing(set_t* set, void* key) {
    size_t idx = fetch_set_idx(set, key, true);

    if (idx < set->__num_buckets) {
        void* cur = set->__buckets[idx];
        /* Existing equal key: the caller replaces or keeps it, so the
         * element count does not change and no growth is needed. */
        if (cur != NULL && cur != TOMBSTONE_VAL)
            return idx;
    }

    set->size++;
    if ((double) set->size / set->__num_buckets > set->load_limit
            || idx >= set->__num_buckets) {
        rehash_set(set);
        /* Bucket positions moved; look the key up again. */
        idx = fetch_set_idx(set, key, true);
        if (idx >= set->__num_buckets)
            crash(
                "Set still full after rehashing, "
                "likely due to bad hash function.\n");
    }
    return idx;
}

/*
 * Insert `key` unless an equal key is already present.
 *
 * Returns the previously stored equal key if one exists; the set is then
 * unchanged. Otherwise stores and returns `key`. A bucket holding
 * TOMBSTONE_VAL is free space left by set_remove, not an existing key, so
 * it is overwritten rather than returned.
 */
void* set_add(set_t* set, void* key) {
    size_t idx = fetch_set_idx_rehashing(set, key);
    void* old_key = set->__buckets[idx];
    if (old_key != NULL && old_key != TOMBSTONE_VAL) return old_key;
    set->__buckets[idx] = key;
    return key;
}

/*
 * Store `key` in the set. If an equal key is already stored, `key` replaces
 * it in its bucket; otherwise `key` is inserted, growing the table if
 * needed.
 */
void set_put(set_t* set, void* key) {
    const size_t slot = fetch_set_idx_rehashing(set, key);
    /* Read the bucket array only after the call above, which may rehash. */
    void** buckets = set->__buckets;
    buckets[slot] = key;
}

/*
 * Remove the stored key equal to `key`.
 *
 * Returns the removed stored pointer, or NULL if no equal key was present.
 * The vacated bucket becomes TOMBSTONE_VAL so later lookups keep probing
 * past it. When the key is absent, the table is left untouched.
 */
void* set_remove(set_t* set, const void* key) {
    size_t idx = fetch_set_idx(set, key, false);
    /* num_buckets means the probe bound was exhausted: key not present.
     * Indexing with it would read past the bucket array. */
    if (idx >= set->__num_buckets) return NULL;

    void* old_key = set->__buckets[idx];
    /* An empty slot means "not present"; do not turn it into a tombstone. */
    if (old_key == NULL) return NULL;

    set->__buckets[idx] = TOMBSTONE_VAL;
    set->size--;
    return old_key;
}

/*
 * Call `foreach_func(key, data)` once for every live key, in bucket order.
 * Empty buckets and tombstones are skipped. `data` is passed through
 * unchanged.
 */
void set_foreach(set_t* set, set_foreach_func_t foreach_func, void* data) {
    for (size_t slot = 0; slot < set->__num_buckets; ++slot) {
        void* const entry = set->__buckets[slot];
        if (entry != NULL && entry != TOMBSTONE_VAL)
            foreach_func(entry, data);
    }
}