用 C++ 實作 Trie

Posted by blueskyson on September 15, 2021

Trie 簡介

Trie,又稱字首樹或字典樹,規則為將單字拆成一個一個字元,每個字元代表一個節點,依照字元在單字中的順序往下長成一棵樹。

Trie 實作方式

首先定義節點物件 node,其必須包含可以往下長的節點 child,在這裡我使用了 STL 的 unordered_map 來實作 child。node 還需要包含一個布林值 is_word,用來判別從根走到這個節點是否能讀作一個完整的單字。Trie 物件中必須包含一個根節點 root,所有插入、移除、搜尋都必須從 root 往下操作、依照字元順序走訪這棵樹。

C++ 實作

以下原始碼也放在我的 github repo https://github.com/blueskyson/data-structure/tree/master/cpp/trie

trie.h

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
#include <unordered_map>
#include <iostream>
#include <string>
#include <vector>

/**
 * Prefix tree (trie) storing words as paths of `char`s starting at `root`.
 *
 * Every node is heap-allocated and owned by its parent; the trie owns the
 * root. Copying a trie performs a deep copy, and destroying it releases the
 * whole tree.
 */
class trie {
public:
    /**
     * A single trie node.
     *
     * `depth` is the number of characters on the path from the root (the root
     * is 0), `is_word` tells whether that path spells a stored word, and
     * `child` maps the next character to the owned child node.
     */
    class node {
    public:
        node();
        node(int);
        ~node();
        // The destructor deletes every child, so a member-wise copy would make
        // two nodes free the same subtree. Forbid copying outright.
        node(const node&) = delete;
        node& operator=(const node&) = delete;
        int depth;
        bool is_word;
        std::unordered_map<char, node*> child;
    };

    /** Creates an empty trie holding only the root node. */
    trie();

    /** Deep-copies `other`, so the two tries never share nodes. */
    trie(const trie& other) : root(clone(other.root)) {}

    /** Replaces this trie's contents with a deep copy of `other`. */
    trie& operator=(const trie& other) {
        if (this != &other) {
            // Build the copy first so a failed allocation leaves *this intact.
            node* fresh = clone(other.root);
            delete root;
            root = fresh;
        }
        return *this;
    }

    /** Releases the root; the node destructor then frees the whole tree. */
    ~trie() { delete root; }

    /** Stores a word in the trie. */
    void insert(std::string);
    /** Returns true when exactly this word is stored. */
    bool search(std::string);
    /** Removes exactly this word, if it is stored. */
    void remove(std::string);
    /** Removes every stored word that starts with the given prefix. */
    void remove_start_with(std::string);
    /** Returns every stored word that starts with the given prefix. */
    std::vector<std::string> list_start_with(std::string);
    /** Returns every stored word. */
    std::vector<std::string> list_all();
    /** Returns the length of the longest stored word (0 when none). */
    int max_depth();
    /** Returns how many words are stored. */
    int word_number();

private:
    /** Recursively duplicates the subtree rooted at `src`; the caller owns the result. */
    static node* clone(const node* src) {
        node* copy = new node(src->depth);
        try {
            copy->is_word = src->is_word;
            for (const auto& entry : src->child) {
                // Reserve the slot first: if the recursive clone throws, the
                // slot still holds nullptr and `delete copy` stays safe.
                node*& slot = copy->child[entry.first];
                slot = clone(entry.second);
            }
        } catch (...) {
            delete copy;
            throw;
        }
        return copy;
    }

    void push_word_recursive(node*, std::string, std::vector<std::string>*);
    void max_depth_recursive(node*, int*);
    void word_num_recursve(node*, int*);
    node* root;
};

trie.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
#include "trie.h"

using std::string;
using std::unordered_map;
using std::vector;

/** Builds a childless node at depth 0 that does not mark the end of a word. */
trie::node::node() : depth(0), is_word(false), child() {}

/** Builds a childless, non-word node located `depth` characters below the root. */
trie::node::node(int depth) : depth(depth), is_word(false), child() {}

/** Frees every child node; each child's destructor in turn frees its own subtree. */
trie::node::~node() {
    for (auto it = child.begin(); it != child.end(); ++it) {
        delete it->second;
    }
}

/** Creates an empty trie consisting of a single root node. */
trie::trie() : root(new node()) {}

/**
 * Adds `word` to the trie.
 *
 * Walks down from the root one character at a time, creating any missing
 * node along the way, and flags the final node as the end of a word.
 * Inserting the empty string flags the root itself.
 */
void trie::insert(string word) {
    node* cur = root;
    for (char c : word) {
        auto next = cur->child.find(c);
        if (next == cur->child.end()) {
            // No edge for this character yet: grow the tree by one node.
            next = cur->child.emplace(c, new node(cur->depth + 1)).first;
        }
        cur = next->second;
    }
    cur->is_word = true;
}

/**
 * Reports whether `word` is stored in the trie.
 *
 * @return true only if the path spelled by `word` exists and its last node
 *         is flagged as a word; a mere prefix of a stored word yields false.
 */
bool trie::search(string word) {
    node* cur = root;
    for (char c : word) {
        auto next = cur->child.find(c);
        if (next == cur->child.end()) {
            return false;
        }
        cur = next->second;
    }
    return cur->is_word;
}

/**
 * Removes `word` from the trie, if present.
 *
 * Only the end-of-word flag is cleared; the nodes along the path are kept.
 * Does nothing when the path for `word` does not exist.
 */
void trie::remove(string word) {
    node* cur = root;
    for (char c : word) {
        auto next = cur->child.find(c);
        if (next == cur->child.end()) {
            return;
        }
        cur = next->second;
    }
    cur->is_word = false;
}

/**
 * Removes every stored word that begins with `word`, including `word` itself.
 *
 * The subtree below the prefix node is deleted and the prefix node is
 * unflagged; the nodes leading to the prefix are kept. Does nothing when
 * the prefix path does not exist.
 */
void trie::remove_start_with(string word) {
    node* cur = root;
    for (char c : word) {
        auto next = cur->child.find(c);
        if (next == cur->child.end()) {
            return;
        }
        cur = next->second;
    }
    // Free each subtree hanging below the prefix, then drop the dangling edges.
    for (auto it = cur->child.begin(); it != cur->child.end(); ++it) {
        delete it->second;
    }
    cur->child.clear();
    cur->is_word = false;
}

/**
 * Collects, in pre-order, every word stored in the subtree rooted at `n`.
 *
 * @param n    node to start from
 * @param word characters spelled on the way to `n`
 * @param list output vector that receives the words found
 */
void trie::push_word_recursive(node* n, std::string word, std::vector<std::string>* list) {
    if (n->is_word) {
        list->push_back(word);
    }
    for (auto it = n->child.begin(); it != n->child.end(); ++it) {
        // Extend the current prefix by the edge character before descending.
        std::string extended = word;
        extended.push_back(it->first);
        push_word_recursive(it->second, extended, list);
    }
}

/**
 * Lists every stored word that begins with `word`.
 *
 * @return the matching words (including `word` itself when it is stored),
 *         or an empty vector when no path for the prefix exists.
 */
vector<string> trie::list_start_with(string word) {
    vector<string> matches;
    node* cur = root;
    for (char c : word) {
        auto next = cur->child.find(c);
        if (next == cur->child.end()) {
            return matches;
        }
        cur = next->second;
    }

    push_word_recursive(cur, word, &matches);
    return matches;
}

/** Returns every word stored in the trie, gathered by walking down from the root. */
vector<string> trie::list_all() {
    vector<string> words;
    push_word_recursive(root, string(), &words);
    return words;
}

/**
 * Raises `*max` to the depth of the deepest word-ending node found in the
 * subtree rooted at `n`. Nodes that do not end a word are ignored.
 */
void trie::max_depth_recursive(node* n, int* max) {
    if (n->is_word) {
        if (*max < n->depth) {
            *max = n->depth;
        }
    }
    for (auto it = n->child.begin(); it != n->child.end(); ++it) {
        max_depth_recursive(it->second, max);
    }
}

/** Returns the length of the longest stored word, or 0 when no word is stored. */
int trie::max_depth() {
    int deepest = 0;
    max_depth_recursive(root, &deepest);
    return deepest;
}

/** Adds to `*num` the number of word-ending nodes in the subtree rooted at `n`. */
void trie::word_num_recursve(node* n, int* num) {
    if (n->is_word) {
        ++*num;
    }
    for (auto it = n->child.begin(); it != n->child.end(); ++it) {
        word_num_recursve(it->second, num);
    }
}

/** Returns how many words are currently stored in the trie. */
int trie::word_number() {
    int total = 0;
    word_num_recursve(root, &total);
    return total;
}

使用範例

example.cpp

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
#include <iostream>
#include "trie.h"
#include <string>
#include <vector>

using namespace std;

void test_search(trie t, string s) {
    cout << "Does \"" + s + "\" exist? "
         << t.search(s) << endl;
}

int main() {
    trie t;

    /* test insert and list_all */
    t.insert("apple");
    t.insert("application");
    t.insert("apt");
    t.insert("banana");
    t.insert("babara");
    cout << "All words:" << endl;
    for (string s : t.list_all()) {
        cout << s << endl;
    }
    cout << endl;

    /* test max_depth and word_number */
    cout << "Max depth (the longest word length): "
         << t.max_depth() << endl;
    cout << "Word number: "
         << t.word_number() << endl << endl;

    /* test search */
    test_search(t, "apple");
    test_search(t, "application");
    test_search(t, "apt");
    test_search(t, "app");
    cout << endl;

    /*test list_start_with */
    cout << "Start with \"app\":" << endl;
    for (string s : t.list_start_with("app")) {
        cout << s << endl;
    }
    cout << endl;
    cout << "Start with \"ba\":" << endl;
    for (string s : t.list_start_with("ba")) {
        cout << s << endl;
    }
    cout << endl;

    /* test remove */
    test_search(t, "banana");
    cout << "Remove \"banana\"" << endl;
    t.remove("banana");
    test_search(t, "banana");
    cout << endl;

    /* test remove_start_with */
    cout << "Remove words starting with \"app\"" << endl;
    t.remove_start_with("app");
    test_search(t, "apple");
    test_search(t, "application");
    test_search(t, "apt");
    cout << endl;

    /* test max_depth and word_number again */
    cout << "All words:" << endl;
    for (string s : t.list_all()) {
        cout << s << endl;
    }
    cout << endl;
    cout << "Max depth (the longest word length): "
         << t.max_depth() << endl;
    cout << "Word number: "
         << t.word_number() << endl << endl;

    return 0;
}

執行

$ g++ example.cpp trie.cpp -o a
$ ./a
All words:
babara
banana
apt
application
apple

Max depth (the longest word length): 11
Word number: 5

Does "apple" exist? 1
Does "application" exist? 1
Does "apt" exist? 1
Does "app" exist? 0

Start with "app":
application
apple

Start with "ba":
babara
banana

Does "banana" exist? 1
Remove "banana"
Does "banana" exist? 0

Remove words starting with "app"
Does "apple" exist? 0
Does "application" exist? 0
Does "apt" exist? 1

All words:
babara
apt

Max depth (the longest word length): 6
Word number: 2