알고리즘

패턴매칭 - KMP(Knuth–Morris–Pratt algorithm) 알고리즘

NickTop 2024. 4. 16. 00:19

KMP 알고리즘은 텍스트에서 패턴이 몇번째에 있는지 찾는 알고리즘입니다

시간복잡도는 O(N+M) 입니다. (N은 텍스트의 길이, M은 패턴의 길이)

 

Naive algorithm

text 와 찾을 패턴

다음과 같은 순서로 text가 i부터 시작했을 때 pattern과 문자가 일치하는지 찾습니다

1. text[0:4] == pattern

2. text[1:5] == pattern

...

def solution(text, pattern):
    answer = []
    for i in range(len(text)):
        for j in range(len(pattern)):
            if i+j>=len(text) or text[i+j]!=pattern[j]:
                break
            if j==len(pattern)-1:
                answer.append(i)
    return answer
print(solution(text,pattern))

 

하지만 이 과정에서 text의 두번째 index는 a와 동일한지 비교를 3번이나 수행하게 되어 비효율적으로 느껴집니다

naive algorithm 문제점

KMP알고리즘은 이를 해결한 알고리즘입니다.

 

KMP

lps(longest prefix which is also suffix)를 활용하여 최적화합니다

lps[i]=l 이라면, pattern[i+1-l:i+1]==pattern[:l]이라는 뜻입니다 (단, l<i)

예제의 pattern의 lps는 아래와 같습니다

aaab의 lps

 

text와 pattern을 처음부터 비교하면 3번째 요소가 다릅니다

여기서, lps 덕분에 아래 그림에서 파란색 동그라미가 서로 같다는 것을 알고 있습니다

KMP에서 pattern처리

 

이를 활용하여, naive 알고리즘은 text[1]부터 다시 비교를 했겠지만, text 인덱스를 다시 앞으로 돌려 비교할 필요없이 빨간색에서 다음 비교할 대상은 그림의 노란색과 같아집니다

다음 비교 대상

 

text에서 비교하는 element가 뒤로 가지않고 항상 앞으로만 가기 때문에 O(N+M)의 시간복잡도를 가진다는 것을 생각해보실수 있습니다

 

LPS도 naive하게 구하면 앞선 naive algorithm과 동일한 문제점을 겪게 됩니다. 마찬가지로 이전에 활용할 알고리즘을 그대로 쓸 수 있습니다

lps 구하기

 

빨간색 지점에서 최초로 차이가 발생합니다. 파란색 네모가 서로 같다는 것을 알고있기 때문에 다음 비교대상은 노란색 지점입니다. (이전의 직전의 prefix로 구해야 하는 것에 주의)

 

코드

def computeLPS(pattern):
    i=1 # 항상 lps[0]=0
    lps = [0]*len(pattern)
    length = 0
    while i<len(pattern):
        while pattern[i]!=pattern[length] and length>0:
            length = lps[length-1]
        if pattern[i]==pattern[length]:
            length+=1
            lps[i]=length
            i+=1
        else:
            lps[i]=0
            i+=1
    return lps
            
def solution(text, pattern):
    answer = []
    lps = computeLPS(pattern)
    print("lps : ",lps)
    i=0
    length=0
    while i<len(text):
        while text[i]!=pattern[length] and length>0:
            length = lps[length-1]
        if text[i]==pattern[length]:
            length+=1
            i+=1
        else:
            i+=1
        if length==len(pattern):
            answer.append(i-length)
            length = lps[length-1]
    return answer
    
text = "aaaaaaaaab"
pattern = "aaab"
print(solution(text,pattern))
text = "ABABDABACDABABCABAB"
pattern = "ABABCABAB"
print(solution(text,pattern))
text = "aaaaaaaaaaaaaaaaaa"
pattern = "aaa"
print(solution(text,pattern))

inner loop 때문에 O(NM)이라고 생각할수있는데 length는 outer loop 한번 당 1밖에 증가하지 않음에 주의하자

 

틀린부분 있으면 말씀해주세요