본문 바로가기

컴퓨터/파이썬 공부정리

[Python] AVL 트리 구현

class BSTNode:
    """A binary-search-tree node: a key, its associated value, and two child links."""

    def __init__(self, key, value):
        # Both child links start empty; the tree functions attach subtrees later.
        self.left = None
        self.right = None
        self.key, self.value = key, value

def search_bst(n, key):
    """Recursively look up ``key`` in the subtree rooted at ``n``.

    Returns the node whose key equals ``key``, or None when it is absent.
    """
    if n is None or n.key == key:
        return n
    # Larger keys live in the right subtree, smaller ones in the left.
    next_node = n.right if n.key < key else n.left
    return search_bst(next_node, key)

def search_bst_iter(n, key):
    """Iteratively look up ``key`` starting at node ``n``.

    Returns the node whose key equals ``key``, or None if no such node exists.
    """
    cur = n
    while cur is not None:
        if cur.key == key:
            break
        cur = cur.left if key < cur.key else cur.right
    return cur

def search_value_bst(n, value):
    """Find a node by value in the subtree rooted at ``n``.

    Values are not ordered in a BST, so every node may have to be visited.
    Nodes are checked in preorder (node, left subtree, right subtree).

    Returns the first matching node, or None if no node holds ``value``.
    """
    if n is None:
        return None
    if n.value == value:
        return n
    found = search_value_bst(n.left, value)
    if found is not None:
        # Stop as soon as the left subtree produced a match.
        return found
    # Only fall through to the right subtree when the left one had no match.
    return search_value_bst(n.right, value)

def search_bst_max(n):
    """Return the right-most (largest-key) node of the subtree at ``n``, or None if empty."""
    if n is None:
        return None
    while n.right is not None:
        n = n.right
    return n

def search_bst_min(n):
    """Return the left-most (smallest-key) node of the subtree at ``n``, or None if empty."""
    if n is None:
        return None
    while n.left is not None:
        n = n.left
    return n

def insert_bst(r, n):
    """Attach node ``n`` beneath root ``r`` following BST key ordering.

    Returns True once ``n`` has been linked into an empty child slot.
    Returns False when a node with an equal key already exists; the tree is
    left untouched in that case.
    """
    cur = r
    while True:
        if n.key < cur.key:
            if cur.left is None:
                cur.left = n
                return True
            cur = cur.left
        elif n.key > cur.key:
            if cur.right is None:
                cur.right = n
                return True
            cur = cur.right
        else:
            return False

# 단말 노드의 삭제
def delete_bst_case1(parent, node, root):
    """Unlink the leaf ``node`` and return the tree's root afterwards.

    A None ``parent`` means the leaf is the root itself, so the tree becomes
    empty and None is returned. Otherwise ``parent``'s link to ``node`` is
    cleared and ``root`` is returned unchanged.
    """
    if parent is None:
        return None
    if parent.right is node:
        parent.right = None
    else:
        parent.left = None
    return root

# 자식이 한 개 있는 노드의 삭제
def delete_bst_case2(parent, node, root):
    """Splice out ``node``, which has exactly one child, and return the root.

    The lone child takes ``node``'s place. It becomes the new root when
    ``node`` is the root; otherwise it replaces ``node`` under ``parent``.
    """
    child = node.right if node.left is None else node.left
    if root == node:
        return child
    if parent.left is node:
        parent.left = child
    else:
        parent.right = child
    return root

# 자식이 두 개 있는 노드의 삭제
def delete_bst_case3(parent, node, root):
    """Delete ``node``, which has two children, and return the tree's root.

    The in-order successor (the left-most node of ``node``'s right subtree)
    is unlinked, and its key/value are copied into ``node``. ``node`` itself
    stays in the tree, so the root object never changes and ``root`` is
    returned as-is. ``parent`` is unused; it is kept so all delete cases share
    one signature.
    """
    succp = node
    succ = node.right
    while succ.left is not None:
        succp = succ
        succ = succ.left

    # The successor has no left child, so its right subtree takes its slot.
    if succp.left is succ:
        succp.left = succ.right
    else:
        succp.right = succ.right

    node.key = succ.key
    node.value = succ.value
    # Return the root; callers assign the result back as the tree root.
    return root

# 모든 경우에 대한 삭제 연산
from typing import *
def delete_bst(root: "BSTNode", key):
    """Delete the node holding ``key`` from the tree at ``root``; return the new root.

    When ``key`` is not present, ``root`` is returned unchanged, so callers
    that store the result as their root keep their tree. An empty tree
    (``root`` is None) yields None.
    """
    if root is None:
        return None

    # Walk down to the target node while tracking its parent.
    parent = None
    node = root
    while node is not None and node.key != key:
        parent = node
        node = node.right if node.key < key else node.left

    if node is None:
        # Key not found: leave the tree intact.
        return root
    if node.left is None and node.right is None:
        return delete_bst_case1(parent, node, root)
    if node.left is None or node.right is None:
        return delete_bst_case2(parent, node, root)
    # Two children: the node is overwritten in place by its successor, so the
    # root object never changes; don't depend on the helper's return value.
    delete_bst_case3(parent, node, root)
    return root
    
def display(n):
    """Print every key of the subtree at ``n`` in ascending (in-order) order, one per line."""
    pending = []
    cur = n
    while pending or cur is not None:
        # Descend as far left as possible, remembering the path.
        while cur is not None:
            pending.append(cur)
            cur = cur.left
        cur = pending.pop()
        print(cur.key)
        cur = cur.right
    
class BSTMap:
    """Map backed by an (unbalanced) binary search tree.

    ``self.root`` is the root node, or None when the map is empty. The tree
    work is delegated to the module-level BST helper functions.
    """

    def __init__(self):
        self.root = None

    def isEmpty(self):
        """Return True when the map holds no nodes."""
        return self.root is None

    def clear(self):
        """Remove every entry by dropping the root."""
        self.root = None

    def size(self):
        """Return the number of nodes in the tree.

        Uses an explicit stack instead of recursion so that deep, degenerate
        trees do not hit Python's recursion limit.
        """
        if self.root is None:
            return 0
        stack = [self.root]
        cnt = 0
        while stack:
            # Take the next pending node off the stack, then queue its children.
            node = stack.pop()
            cnt += 1
            if node.left is not None:
                stack.append(node.left)
            if node.right is not None:
                stack.append(node.right)
        return cnt

    def search(self, key):
        """Return the node whose key equals ``key``, or None (via search_bst)."""
        return search_bst(self.root, key)

    def searchValue(self, value):
        """Look up a node by value (via search_value_bst)."""
        return search_value_bst(self.root, value)

    def findMax(self):
        """Return the node with the largest key, or None if the map is empty."""
        return search_bst_max(self.root)

    def findMin(self):
        """Return the node with the smallest key, or None if the map is empty."""
        return search_bst_min(self.root)

    def insert(self, key, value=None):
        """Insert ``key``/``value``; a duplicate key leaves the tree unchanged."""
        n = BSTNode(key, value)
        if self.isEmpty():
            self.root = n
        else:
            insert_bst(self.root, n)

    def delete(self, key):
        """Delete ``key`` and store the root returned by delete_bst."""
        self.root = delete_bst(self.root, key)

    def display(self, msg="BSTMap :"):
        """Print ``msg`` followed by every key in ascending (in-order) order."""
        def inorder(n):
            # None guard lets an empty map print just the message.
            if n is None:
                return
            inorder(n.left)
            print(n.key, end=' ')
            inorder(n.right)

        print(msg, end='')
        inorder(self.root)
        print()
####
# Demo: build a plain BST from a fixed key list and print its keys in order.
# The map is not called `map`, so the builtin map() stays usable.
bst_map = BSTMap()
data = [35, 18, 7, 26, 12, 3, 68, 22, 30, 99]
print("[삽입 연산] : ", data)
for key in data:
    bst_map.insert(key)
bst_map.display("[중위 순회] : ")

####
def rotateLL(A):
    """Single right rotation for a left-left imbalance at ``A``.

    ``A``'s left child becomes the subtree root and ``A`` becomes its right
    child. The left child's former right subtree is re-hung as ``A.left``.
    Returns the new subtree root.
    """
    pivot = A.left
    A.left, pivot.right = pivot.right, A
    return pivot

def rotateRR(A):
    """Single left rotation for a right-right imbalance at ``A``.

    ``A``'s right child becomes the subtree root and ``A`` becomes its left
    child. The right child's former left subtree is re-hung as ``A.right``.
    Returns the new subtree root.
    """
    pivot = A.right
    A.right, pivot.left = pivot.left, A
    return pivot

def rotateRL(A):
    """Double rotation for a right-left imbalance at ``A``.

    First rotates ``A``'s right child to the right (rotateLL), then rotates
    ``A`` itself to the left (rotateRR). Returns the new subtree root.
    """
    A.right = rotateLL(A.right)
    return rotateRR(A)

def rotateLR(A):
    """Double rotation for a left-right imbalance at ``A``.

    First rotates ``A``'s left child to the left (rotateRR), then rotates
    ``A`` itself to the right (rotateLL). Returns the new subtree root.
    """
    A.left = rotateRR(A.left)
    return rotateLL(A)

def _subtree_height(n):
    """Return the height of the subtree at ``n`` in nodes (0 for None, 1 for a leaf)."""
    height = 0
    level = [n] if n is not None else []
    # Breadth-first, one tree level per iteration. Iterative so that deep,
    # unbalanced trees do not hit the recursion limit.
    while level:
        height += 1
        next_level = []
        for node in level:
            if node.left is not None:
                next_level.append(node.left)
            if node.right is not None:
                next_level.append(node.right)
        level = next_level
    return height


def calc_height_diff(root: "BSTNode"):
    """Return the balance factor of ``root``: height(left) - height(right).

    Positive values mean the left subtree is taller, and negative values mean
    the right subtree is taller. Returns 0 for None or a leaf.

    Both subtree heights are measured directly. A child's balance factor is
    not a height, so deriving this value from it misjudges imbalance and can
    trigger a rotation around a missing (None) child.
    """
    if root is None:
        return 0
    return _subtree_height(root.left) - _subtree_height(root.right)
    
def reBalance(parent):
    """Restore the AVL balance at ``parent`` and return the new subtree root.

    A balance factor above 1 is left-heavy (LL or LR case). A factor below -1
    is right-heavy (RR or RL case). Otherwise ``parent`` is returned unchanged.

    A child with balance 0 gets a single rotation. This follows the standard
    AVL rule; a double rotation in that case can leave the subtree unbalanced.
    """
    hDiff = calc_height_diff(parent)
    if hDiff > 1:
        # Left child leaning left (or level) -> single rotation; leaning right -> double.
        if calc_height_diff(parent.left) >= 0:
            parent = rotateLL(parent)
        else:
            parent = rotateLR(parent)
    elif hDiff < -1:
        # Right child leaning right (or level) -> single rotation; leaning left -> double.
        if calc_height_diff(parent.right) <= 0:
            parent = rotateRR(parent)
        else:
            parent = rotateRL(parent)
    return parent

def insert_avl(parent, node):
    """Insert ``node`` into the AVL subtree rooted at ``parent``.

    Returns the (possibly rotated) subtree root. The caller must store it in
    place of ``parent``. A duplicate key prints an error and returns
    ``parent`` unchanged, so the caller's link is not overwritten.
    """
    if node.key < parent.key:
        if parent.left is not None:
            parent.left = insert_avl(parent.left, node)
        else:
            parent.left = node
        return reBalance(parent)

    if node.key > parent.key:
        if parent.right is not None:
            parent.right = insert_avl(parent.right, node)
        else:
            parent.right = node
        return reBalance(parent)

    print("중복된 키 에러")
    # Return the untouched subtree; returning None here would make the caller
    # replace this subtree (or the whole root) with None.
    return parent

class AVLMap(BSTMap):
    """BSTMap variant whose insert rebalances the tree through insert_avl.

    NOTE(review): ``delete`` is inherited from BSTMap and uses delete_bst,
    which performs no rebalancing. Deletions can therefore break the AVL
    balance.
    """

    def __init__(self):
        super().__init__()

    def insert(self, key, value=None):
        """Insert ``key``/``value`` and store the rebalanced root."""
        n = BSTNode(key, value)
        if self.isEmpty():
            self.root = n
        else:
            self.root = insert_avl(self.root, n)

    def display(self, msg = "AvlMap :"):
        """Print ``msg``, then every key in level order with its balance factor in parentheses."""
        from collections import deque

        print(msg, end=' ')

        que = deque()
        # An empty map prints only the message.
        if self.root is not None:
            que.append(self.root)
        while que:
            node = que.popleft()
            print(f" {node.key}({calc_height_diff(node)})", end=' ')
            if node.left is not None:
                que.append(node.left)
            if node.right is not None:
                que.append(node.right)
        print()

# Demo: insert keys one at a time and print the tree (level order) after each step.
# An f-string label also works for non-integer keys; "%d" raised TypeError for strings.
keys = [7, 8, 9, 2, 1, 5, 3, 6, 4]
avl_map = AVLMap()

for i in keys:
    avl_map.insert(i)
    avl_map.display(f"AVL({i}): ")

5/23 수정

key가 정수인 경우 정상적으로 동작하지만 숫자가 아닌 문자열의 경우 제대로 동작하지 않음.

AVL 트리를 만드는 데 있어 회전을 하려고 하면 회전 함수(rotateLL 등)의 B가 None이 되어 에러가 발생함.