Skip to content

Mitaka Club

正規表現マッチャをNFAで書く

Regular Expression Matching Can Be Simple And Fast を読んでいる。正規表現を NFA に変換して評価するが、a?ⁿaⁿ(a? を n 個並べた後に a を n 個並べたもの。n = 3 なら a?a?a?aaa)という正規表現に対して recursive backtracking なアプローチだと全てのノードを走査することがあるので O(2^n) の計算量がかかる一方で、ここで紹介されている Ken Thompson's algorithm では各ステップごとにありうるすべての状態をリストに保持するので O(mn) で済むとされている(m はテキストのサイズ、n は正規表現のサイズ)。

一旦正規表現から NFA を作るところまで読み終わって、いつも通り Rust で書き直した。簡単な正規表現、テスト抜きで300行くらいだしオートマトンの実例としていい題材だった。サポートされている記法は以下の通り。

| 文字 | 意味 |
| --- | --- |
| c | 文字 c そのもの |
| \| | または |
| + | 直前の文字の1回以上の反復 |
| * | 直前の文字の0回以上の反復 |
| ? | 直前の文字の0回または1回の出現 |

use std::cell::{Cell, RefCell};
use std::rc::Rc;
use std::vec;

/// Placeholder entry point: prints a fixed greeting and returns.
fn main() {
    let greeting = "Hello, world!";
    println!("{}", greeting);
}

/// Makes concatenation explicit by inserting a `.` operator between adjacent operands.
///
/// Every input character is copied to the output unchanged. A `.` is appended after a
/// character unless that character is `(` or `|`, or the character that follows is one
/// of `|`, `*`, `+`, `?`, `)`, or there is no following character.
///
/// For example `ab+c` becomes `a.b+.c` and `a(bb)+a` becomes `a.(b.b)+.a`.
fn lexer(s: &str) -> Vec<char> {
    let mut tokens: Vec<char> = Vec::new();
    let mut remaining = s.chars().peekable();

    while let Some(current) = remaining.next() {
        tokens.push(current);

        // Nothing can be concatenated directly onto an opening paren or an alternation bar.
        let awaits_operand = matches!(current, '(' | '|');
        // An operator or closing paren attaches to what was just emitted instead.
        let operand_follows = remaining
            .peek()
            .map_or(false, |next| !matches!(next, '|' | '*' | '+' | '?' | ')'));

        if !awaits_operand && operand_follows {
            tokens.push('.');
        }
    }

    tokens
}

/// Converts an infix token stream (with explicit `.` concatenation, as produced by
/// `lexer`) into postfix order using an operator stack.
///
/// Operands are emitted immediately. An operator first flushes every stacked operator
/// whose precedence is greater than or equal to its own, then is stacked itself.
/// `(` is stacked as a marker and `)` flushes back to the nearest `(`, discarding it.
/// Whatever remains on the stack at the end is emitted top-first.
fn re2post(re: Vec<char>) -> Vec<char> {
    /// Precedence of a token; higher binds tighter.
    fn binding_strength(token: char) -> usize {
        match token {
            '(' | ')' => 1,
            '|' => 2,
            '.' => 3, // concat
            '*' | '+' | '?' => 4,
            _ => 6,
        }
    }

    let mut postfix: Vec<char> = Vec::new();
    let mut pending: Vec<char> = Vec::new();

    for token in re {
        match token {
            '(' => pending.push(token),
            ')' => {
                // Flush the group; the matching '(' is popped and dropped.
                while let Some(op) = pending.pop() {
                    if op == '(' {
                        break;
                    }
                    postfix.push(op);
                }
            }
            '.' | '*' | '+' | '?' | '|' => {
                let strength = binding_strength(token);
                while let Some(&top) = pending.last() {
                    // '(' has the lowest precedence, so it is never flushed here.
                    if binding_strength(top) < strength {
                        break;
                    }
                    pending.pop();
                    postfix.push(top);
                }
                pending.push(token);
            }
            operand => postfix.push(operand),
        }
    }

    postfix.extend(pending.into_iter().rev());
    postfix
}

/// A single NFA state.
///
/// Transitions live in shared, mutable cells so that a state can be created with
/// its transitions still empty and have them filled in later while `post2nfa`
/// stitches fragments together.
#[derive(Debug)]
struct State {
    /// What this state does; see `StateKind`.
    kind: StateKind,
    /// Primary transition. For a `Lit` it is the state followed by `step` once the
    /// character has matched; for a `Split` it is the first branch followed by
    /// `add_state`. `post2nfa` leaves it `None` on the `Match` state.
    out: Rc<RefCell<Option<Rc<State>>>>,
    /// Second branch of a `Split`. `post2nfa` never fills it in for `Lit` or
    /// `Match` states.
    out1: Rc<RefCell<Option<Rc<State>>>>,
    /// List marker: initialised to 0 by `post2nfa` and compared against the
    /// current list counter in `add_state`.
    list_id: Cell<usize>,
}

/// The role a `State` plays in the NFA.
#[derive(Debug)]
enum StateKind {
    /// Matches exactly the contained character.
    Lit(char),
    /// Consumes no input; branches to both `out` and `out1`.
    Split,
    /// Accepting state.
    Match,
}

/// A partially built NFA, as kept on the construction stack in `post2nfa`.
#[derive(Debug)]
struct Fragment {
    /// Entry state of the fragment.
    start: Rc<State>,
    /// Transition cells of states inside the fragment that do not point anywhere
    /// yet; they are filled in when the fragment is connected to what follows.
    outs: Vec<Rc<RefCell<Option<Rc<State>>>>>,
}

fn post2nfa(post: Vec<char>) -> Rc<State> {
    fn patch(outs: Vec<Rc<RefCell<Option<Rc<State>>>>>, start: Rc<State>) {
        for out in outs {
            *out.borrow_mut() = Some(Rc::clone(&start));
        }
    }

    let mut stack: Vec<Fragment> = Vec::new();

    for c in post {
        match c {
            '.' => {
                let e2 = stack.pop().unwrap();
                let e1 = stack.pop().unwrap();
                patch(e1.outs.clone(), Rc::clone(&e2.start));
                stack.push(Fragment {
                    start: e1.start,
                    outs: e2.outs,
                })
            }
            '|' => {
                let e2 = stack.pop().unwrap();
                let e1 = stack.pop().unwrap();
                let s = State {
                    kind: StateKind::Split,
                    out: Rc::new(RefCell::new(Some(Rc::clone(&e1.start)))),
                    out1: Rc::new(RefCell::new(Some(Rc::clone(&e2.start)))),
                    list_id: Cell::new(0),
                };
                stack.push(Fragment {
                    start: Rc::new(s),
                    outs: e1
                        .outs
                        .iter()
                        .chain(e2.outs.iter())
                        .map(|out| Rc::clone(out))
                        .collect::<Vec<_>>(),
                })
            }
            '?' => {
                let e = stack.pop().unwrap();
                let s = State {
                    kind: StateKind::Split,
                    out: Rc::new(RefCell::new(Some(Rc::clone(&e.start)))),
                    out1: Rc::new(RefCell::new(None)),
                    list_id: Cell::new(0),
                };
                let outs = e
                    .outs
                    .iter()
                    .chain(vec![Rc::clone(&s.out1)].iter())
                    .map(|out| Rc::clone(out))
                    .collect::<Vec<_>>();
                stack.push(Fragment {
                    start: Rc::new(s),
                    outs,
                })
            }
            '*' => {
                let e = stack.pop().unwrap();
                let s = Rc::new(State {
                    kind: StateKind::Split,
                    out: Rc::new(RefCell::new(Some(Rc::clone(&e.start)))),
                    out1: Rc::new(RefCell::new(None)),
                    list_id: Cell::new(0),
                });
                patch(e.outs, Rc::clone(&s));
                let out = Rc::clone(&s.out1);
                stack.push(Fragment {
                    start: s,
                    outs: vec![out],
                })
            }
            '+' => {
                let e = stack.pop().unwrap();
                let s = Rc::new(State {
                    kind: StateKind::Split,
                    out: Rc::new(RefCell::new(Some(Rc::clone(&e.start)))),
                    out1: Rc::new(RefCell::new(None)),
                    list_id: Cell::new(0),
                });
                patch(e.outs, Rc::clone(&s));
                let out = Rc::clone(&s.out1);
                stack.push(Fragment {
                    start: e.start,
                    outs: vec![out],
                })
            }
            _ => {
                let s = State {
                    kind: StateKind::Lit(c),
                    out: Rc::new(RefCell::new(None)),
                    out1: Rc::new(RefCell::new(None)),
                    list_id: Cell::new(0),
                };

                let out = Rc::clone(&s.out);
                let f = Fragment {
                    start: Rc::new(s),
                    outs: vec![out],
                };
                stack.push(f);
            }
        }
    }

    let e = stack.pop().unwrap();
    patch(
        e.outs,
        Rc::new(State {
            kind: StateKind::Match,
            out: Rc::new(RefCell::new(None)),
            out1: Rc::new(RefCell::new(None)),
            list_id: Cell::new(0),
        }),
    );
    e.start
}

/// Runs the NFA rooted at `start` over `s` and reports whether the whole string
/// is accepted.
///
/// Two state lists are kept: the states the NFA is currently in and the states it
/// will be in after the next character. The list counter starts at 1 for the
/// initial list and is bumped once more before the first `step`.
fn matches(start: Rc<State>, s: &str) -> bool {
    let mut generation: usize = 1;
    let mut current: Vec<Rc<State>> = Vec::new();
    let mut upcoming: Vec<Rc<State>> = Vec::new();

    add_state(&mut current, &start, &mut generation);
    generation += 1;

    for ch in s.chars() {
        step(&current, &mut upcoming, ch, &mut generation);
        // The freshly computed list becomes the current one; the old one is reused.
        std::mem::swap(&mut current, &mut upcoming);
    }

    is_match(&current)
}

/// Returns `true` if any state in `nlist` is the accepting `Match` state.
fn is_match(nlist: &[Rc<State>]) -> bool {
    nlist
        .iter()
        .any(|state| matches!(state.kind, StateKind::Match))
}

/// Advances the NFA by one input character.
///
/// Bumps the list counter `cnt`, empties `nlist`, and then, for every `Lit` state
/// in `clist` whose character equals `c`, adds that state's successor to `nlist`
/// via `add_state`. States of any other kind in `clist` are ignored.
///
/// # Panics
///
/// Panics if a matching `Lit` state has an `out` that was never patched.
fn step(clist: &[Rc<State>], nlist: &mut Vec<Rc<State>>, c: char, cnt: &mut usize) {
    *cnt += 1;
    nlist.clear();

    for state in clist {
        if !matches!(state.kind, StateKind::Lit(lit) if lit == c) {
            continue;
        }

        // Read-only access: take a shared borrow and release it before recursing,
        // so no RefCell guard is held while `add_state` walks the graph.
        let successor = state.out.borrow().clone();
        let successor = successor.expect("literal state has an unpatched `out` transition");
        add_state(nlist, &successor, cnt);
    }
}

/// Adds `state` to `list` and, if it is a `Split`, also everything reachable from it
/// through its two branches. Each state ends up in `list` at most once.
///
/// `cnt` identifies the list being built. A state is stamped with `*cnt` before its
/// branches are followed, so that a cycle of `Split` states (as produced by a
/// pattern such as `(a*)*`) terminates instead of recursing forever.
///
/// The stamp is stored on the NFA node itself and therefore outlives the list, while
/// `matches` restarts its counter at 1 on every run. A stamp equal to `*cnt` is thus
/// only treated as "already added" after confirming that the state really is in
/// `list`; this scan is only paid for repeat visits.
///
/// States are pushed before their branches are followed; `step` and `is_match`
/// treat the list as a set, so the order carries no meaning.
///
/// # Panics
///
/// Panics if a `Split` state has a branch that was never patched.
fn add_state(list: &mut Vec<Rc<State>>, state: &Rc<State>, cnt: &mut usize) {
    let generation = *cnt;

    let stamped = state.list_id.get() == generation;
    if stamped && list.iter().any(|seen| Rc::ptr_eq(seen, state)) {
        return;
    }

    // Mark and record first so that re-entry through a cycle stops above.
    state.list_id.set(generation);
    list.push(Rc::clone(state));

    if let StateKind::Split = state.kind {
        // Clone the targets out of the cells so no RefCell borrow is held while recursing.
        let first = state.out.borrow().clone();
        let second = state.out1.borrow().clone();
        let first = first.expect("split state has an unpatched `out` transition");
        let second = second.expect("split state has an unpatched `out1` transition");

        add_state(list, &first, cnt);
        add_state(list, &second, cnt);
    }
}

// Unit tests for each pipeline stage; compiled only for test builds.
#[cfg(test)]
mod test {
    use super::*;

    /// Runs the full pipeline (`lexer`, `re2post`, `post2nfa`) on a pattern.
    fn compile(re: &str) -> Rc<State> {
        post2nfa(re2post(lexer(re)))
    }

    /// Returns the state a transition cell points at; fails the test if it is unset.
    fn target(cell: &Rc<RefCell<Option<Rc<State>>>>) -> Rc<State> {
        let state = cell.borrow().clone();
        state.expect("transition should have been patched")
    }

    #[test]
    fn lexer_test() {
        let cases = [
            ("abc", "a.b.c"),
            ("ab|a", "a.b|a"),
            ("ab+c", "a.b+.c"),
            ("a(bb)+a", "a.(b.b)+.a"),
        ];
        for &(re, expected) in cases.iter() {
            assert_eq!(lexer(re).iter().collect::<String>(), expected);
        }
    }

    #[test]
    fn re2postfix_test() {
        let cases = [
            ("abc", "ab.c."),
            ("ab|a", "ab.a|"),
            ("ab+c", "ab+.c."),
            ("a(bb)+a", "abb.+.a."),
        ];
        for &(re, expected) in cases.iter() {
            assert_eq!(re2post(lexer(re)).iter().collect::<String>(), expected);
        }
    }

    #[test]
    fn post2nfa_test() {
        // Single literal: a -> match
        let start = post2nfa(vec!['a']);
        assert!(matches!(start.kind, StateKind::Lit('a')));
        assert!(matches!(target(&start.out).kind, StateKind::Match));

        // Concatenation: a -> b
        let start = post2nfa(vec!['a', 'b', '.']);
        assert!(matches!(start.kind, StateKind::Lit('a')));
        assert!(matches!(target(&start.out).kind, StateKind::Lit('b')));

        // Alternation: split -> (a | b)
        let start = post2nfa(vec!['a', 'b', '|']);
        assert!(matches!(start.kind, StateKind::Split));
        assert!(matches!(target(&start.out).kind, StateKind::Lit('a')));
        assert!(matches!(target(&start.out1).kind, StateKind::Lit('b')));

        // Optional: split -> (a | match)
        let start = post2nfa(vec!['a', '?']);
        assert!(matches!(start.kind, StateKind::Split));
        assert!(matches!(target(&start.out).kind, StateKind::Lit('a')));
        assert!(matches!(target(&start.out1).kind, StateKind::Match));

        // Star: split -> (a -> split | match)
        let start = post2nfa(vec!['a', '*']);
        assert!(matches!(start.kind, StateKind::Split));
        let lit = target(&start.out);
        assert!(matches!(lit.kind, StateKind::Lit('a')));
        assert!(matches!(target(&lit.out).kind, StateKind::Split));
        assert!(lit.out1.borrow().is_none());
        assert!(matches!(target(&start.out1).kind, StateKind::Match));

        // Plus: a -> split -> (a | ...)
        let start = post2nfa(vec!['a', '+']);
        assert!(matches!(start.kind, StateKind::Lit('a')));
        let split = target(&start.out);
        assert!(matches!(split.kind, StateKind::Split));
        assert!(matches!(target(&split.out).kind, StateKind::Lit('a')));
    }

    #[test]
    fn matches_test() {
        // (pattern, input, expected result); the NFA is rebuilt for every case.
        let cases = [
            ("a", "a", true),
            ("a", "ab", false),
            ("a", "b", false),
            ("ab", "ab", true),
            ("a|b", "a", true),
            ("a|b", "b", true),
            ("a|b", "c", false),
            ("abcd|bc?da", "abcd", true),
            ("abcd|bc?da", "bcda", true),
            ("abcd|bc?da", "bda", true),
            ("a?", "", true),
            ("a?", "a", true),
            ("a?a", "a", true),
            ("a?a", "aa", true),
            ("a*", "", true),
            ("a*", "a", true),
            ("a*", "aaaaaaaa", true),
            ("a+", "", false),
            ("a+", "a", true),
            ("a+", "aaaaaaaa", true),
            ("a?a?a?aaa", "aaa", true),
            ("a?a?a?aaa", "aaaaaa", true),
        ];
        for &(re, input, expected) in cases.iter() {
            assert_eq!(
                matches(compile(re), input),
                expected,
                "pattern {:?} against input {:?}",
                re,
                input
            );
        }
    }
}

参考