Python) AVLTree Search를 통해서 특정 값의 범위 찾기 (ChatGPT와 함께)

2023. 3. 22. 22:43분석 Python/구현 및 자료

FastAVLTree는 AVL 트리 데이터 구조의 구현을 제공하는 파이썬 라이브러리입니다. AVL 트리는 자체 균형 이진 검색 트리의 일종으로, Georgy Adelson-Velsky와 Evgenii Landis에 의해 1962년 처음 소개되었으며, 이들의 이름을 따서 명명되었습니다.

AVL 트리는 모든 노드의 왼쪽과 오른쪽 서브트리의 높이 차가 최대 하나 이하인 이진 검색 트리입니다. 이는 트리가 항상 균형을 유지하므로 트리의 높이가 노드 수에 대해 로그 적이라는 것을 보장합니다. 이 속성은 AVL 트리가 검색, 삽입 및 삭제 작업을 빠르게 수행할 수 있도록 보장하며, 최악의 경우 시간 복잡도는 노드 수에 대해 O(log n)입니다.

FastAVLTree는 Python에서 AVL 트리의 빠르고 메모리 효율적인 구현을 제공하기 위해 설계되었으며, 성능을 향상하기 위한 여러 가지 최적화를 포함합니다. 이러한 최적화에는 메모리 오버헤드를 줄이기 위해 저수준 메모리 관리 기능을 사용하고, 인라이닝 및 루프 언론링을 사용하여 함수 호출 오버헤드를 줄이고, 비트 조작을 사용하여 코드에서 불필요한 분기를 피하는 것이 포함됩니다.

전반적으로, FastAVLTree는 데이터의 효율적인 저장 및 검색이 필요한 애플리케이션에 유용한 라이브러리이며, 데이터베이스, 검색 엔진 및 머신러닝 알고리즘의 데이터 구조를 구현하는 데 사용될 수 있습니다.

 

 

모든 노드의 왼쪽 및 오른쪽 하위 트리의 높이는 최대 1만큼 차이가 나며, 이는 높이 간의 차이가 항상 1,0 또는 -1임을 의미합니다

 

 

그림 (a)의 트리는 AVL 트리입니다. 보시다시피 모든 노드에 대해 왼쪽 하위 트리와 오른쪽 하위 트리 사이의 높이 차이는 mod(1)을 초과하지 않습니다.

그림 (b)에서 값이 3인 노드를 확인하면 왼쪽 하위 트리는 height=2이고 오른쪽 하위 트리는 height=0입니다. 따라서 차이는 mod(2-0) = 2입니다. 따라서 AVL 속성을 만족하지 않으며 AVL 트리가 아닙니다.

 

 

 

위의 예에서 모든 노드의 균형 요소는 -1과 +1 사이입니다. 따라서 AVL 트리입니다.

 

가장 이해하기 쉬웠던 동영상이었다.

 

https://www.youtube.com/watch?v=mGF61O21W-c&ab_channel=AlexTemnok 

비교

 

사용

AVL 트리는 자체 균형 이진 검색 트리에서 빠르고 효율적인 노드 검색, 삽입 및 삭제가 필요할 때 일반적으로 사용됩니다. 특히 AVL 트리는 다음과 같은 상황에서 유용합니다.

 

  • 데이터 세트는 동적이고 끊임없이 변화합니다.
  • 데이터 세트가 크고 효율적인 검색, 삽입 및 삭제 작업이 필요합니다.
  • 데이터 세트는 수명 기간 동안 삽입 및 삭제 비율이 높을 것으로 예상됩니다.

AVL 트리는 시간이 지남에 따라 데이터 세트가 변경되더라도 균형 잡힌 구조를 유지할 수 있으므로 삽입 및 삭제 횟수를 미리 알 수 없는 경우에 특히 효과적입니다. 이렇게 하면 트리가 효율적으로 유지되고 검색 및 가져오기 작업을 위한 빠른 액세스 시간이 제공됩니다.

 

 

예시 1

해당 코드는 chatgpt의 도움을 받아 빠르게 구성해 봤다.

좀 수정하면서 해도 금방 할 수 있어서 정말 편하다는 것을 느꼈다.

본 글에서는 특정 함수들의 y를 알 경우에 x의 범위를 찾는 코드를 보고자 한다.

 

import matplotlib.pyplot as plt
from bintrees import FastAVLTree

# Define the range of x values to search over
x_min = 0
x_max = 10

# Define the tolerance for the search
tol = 1e-3

import numpy as np
from sortedcontainers import SortedDict

# Define the step function
step = lambda x, t: 1 if x >= t else 0

# Define the mixed step function for f1 using only linear and step functions
f1 = lambda x: 1.0 * step(x, 2) + (x - 2) * step(x, 2) * (1 - step(x, 4)) + \
               (4 - x) * step(x, 4) * (1 - step(x, 6)) + 1.0 * step(x, 6)

# Define the mixed step function for f2 using only linear and step functions
f2 = lambda x: 1.0 * step(x, 1) + (x - 1) * step(x, 1) * (1 - step(x, 3)) + \
               (3 - x) * step(x, 3) * (1 - step(x, 5)) + 1.0 * step(x, 5)

# Define the mixed step function for f3 using only linear and step functions
f3 = lambda x: 0.5 * (1 + step(x, 0)) + (x - 1) * step(x, 1) * (1 - step(x, 2)) + \
               (2 - x) * step(x, 2) * (1 - step(x, 3)) + (3 - x) * step(x, 3) * (1 - step(x, 4)) + \
               (4 - x) * step(x, 4) * (1 - step(x, 5)) + (5 - x) * step(x, 5) * (1 - step(x, 6)) + \
               0.5 * (1 - step(x, 6))

# Define the mixed step function for f4 using only linear and step functions
f4 = lambda x: 1.0 * step(x, 1) + (x - 1) * step(x, 1) * (1 - step(x, 3)) + \
               (3 - x) * step(x, 3) * (1 - step(x, 5)) + 1.0 * step(x, 5)

# Define the mixed step function for f5 using only linear and step functions
f5 = lambda x: 0.5 * (1 + step(x, 0)) + (x - 1) * step(x, 1) * (1 - step(x, 3)) + \
               (3 - x) * step(x, 3) * (1 - step(x, 5)) + (5 - x) * step(x, 5) * (1 - step(x, 7)) + \
               0.5 * (1 - step(x, 7))

# Define the mixed step function for f6 using only linear and step functions
f6 = lambda x: 1.0 * step(x, 2) + (x - 2) * step(x, 2) * (1 - step(x, 4)) + \
               (4 - x) * step(x, 4) * (1 - step(x, 6)) + 1.0 * step(x, 6)

x_min = 0
x_max = 10

input_functions = [f1,f2,f3,f4,f5,f6]
x_list = []
y_result = []
for x in np.linspace(int(x_min),int(x_max),num=100):
    x_list.append(x)
    y_result.append([f(x) for f in input_functions])
    
plt.figure(figsize=(15,6))
for i in range(np.array(y_result).shape[1]) :
    plt.plot(x_list , np.array(y_result)[:,i],label=f'f{i+1}')
plt.legend()
plt.show()

 

아래 그림처럼 전체 범위에서 6개가 다 만족하는 부분은 저 끝부분인데 저걸 찾고자 한다.

이때 AVL Tree 코드를 사용해서 빠르게 찾는 것을 진행해 봤다.

 

# Define a function to check if the functions are satisfied for a given x value
def check_functions(x, ys):
    fxs = tree[x]
    return all([np.abs(fx - y) < tol for fx, y in zip(fxs, ys)])

# Define a function to find the range of x values that satisfy all input functions
def find_x_range(*ys):
    keys = list(tree.keys())
    left = 0
    right = len(keys) - 1

    # Find the estimated index of the x value that satisfies all input functions using interpolation search
    while left <= right and not check_functions(keys[left], ys) and not check_functions(keys[right], ys):
        mid = int(left + (right - left) * (ys[0] - tree[keys[left]][0]) / (tree[keys[right]][0] - tree[keys[left]][0]))
        if check_functions(keys[mid], ys):
            right = mid
            left = mid
        elif tree[keys[mid]][0] < ys[0]:
            left = mid + 1
        else:
            right = mid - 1

    # Find the lower bound of the range of x values that satisfies all input functions
    if check_functions(keys[left], ys):
        lower_bound = left
    else:
        lower_bound = right

    left = lower_bound
    right = len(keys) - 1

    # Find the estimated index of the x value that satisfies all input functions using interpolation search
    while left <= right and not check_functions(keys[left], ys) and not check_functions(keys[right], ys):
        mid = int(left + (right - left) * (ys[0] - tree[keys[left]][0]) / (tree[keys[right]][0] - tree[keys[left]][0]))
        if check_functions(keys[mid], ys):
            left = mid
            right = mid
        elif tree[keys[mid]][0] < ys[0]:
            left = mid + 1
        else:
            right = mid - 1

    # Find the upper bound of the range of x values that satisfies all input functions
    if check_functions(keys[right], ys):
        upper_bound = right
    else:
        upper_bound = left

    # The range of x values that satisfies all input functions is given by [keys[lower_bound], keys[upper_bound]]
    return [keys[lower_bound], keys[upper_bound]]

# Define a function to find all ranges of x values that satisfy all input functions
def find_all_x_ranges(*ys):
    x_ranges = []
    current_range = []
    for x in np.linspace(x_min, x_max, num=int((x_max - x_min) / tol)):
        if check_functions(x, ys):
            if not current_range:
                current_range.append(x)
        else:
            if current_range:
                current_range.append(x)
                x_ranges.append(current_range)
                current_range = []
    if current_range:
        current_range.append(x_max)
        x_ranges.append(current_range)
    return x_ranges

y1 = 2.0
y2 = 2.0
y3 = 1.0
y4 = 2.0
y5 = 1.0
y6 = 2.0 

# Define the list of input functions
input_functions = [f1, f2, f3,f4,f5,f6] # Add any additional functions here

# Create an AVL tree to store the values of the functions
tree = FastAVLTree()

# Compute the values of the functions for each x value in the range and insert them into the tree
for x in np.linspace(x_min, x_max, num=int((x_max - x_min) / tol)):
    tree[x] = [f(x) for f in input_functions]


# Find the range of x values that satisfy f1(x) = 1, f2(x) = 0.5, and f3(x) = -0.5
x_range = find_all_x_ranges(y1,y2,y3,y4,y5,y6)
print(x_range)

대단함을 다시 한번 느낀다...

[[7.000700070007, 10]]

 

예시 2

 

일반화된 코드로 다른 예시로 진행해 봤다.

가장 좋았던 점은 예시도 잘 만들어줘서 고민하는 시간을 많이 줄였다.

# Define a step function for a function with repeating y values
step = lambda x, t: 1 if x >= t else 0
def step_fn1(x):
    return 1 + step(x, 2) * (2 - 1) + step(x, 4) * (1 - 2) + step(x, 6) * (2 - 1)

# Define the second repeating step function with y values of 1 and 2 only
def step_fn2(x):
    return 2 - step(x, 2) * (2 - 1) - step(x, 8) * (2 - 1)

# Define the third repeating step function with y values of 1 and 2 only
def step_fn3(x):
    return 1 + step(x, 2) * (2 - 1) - step(x, 4) * (1 - 2) - step(x, 6) * (2 - 1)

x_min = 0
x_max = 10

input_functions = [step_fn1,step_fn2,step_fn3]
x_list = []
y_result = []
for x in np.linspace(int(x_min),int(x_max),num=100):
    x_list.append(x)
    y_result.append([f(x) for f in input_functions])

for i in range(np.array(y_result).shape[1]) :
    plt.plot(x_list , np.array(y_result)[:,i],label=f'f{i+1}')
plt.legend()
plt.show()

저기 겹치는 부분을 찾게 하는 작업을 수행해 봤다.

y1 = 2.0
y2 = 1.0
y3 = 2.0

# Define the list of input functions
input_functions = [step_fn1,step_fn2,step_fn3] # Add any additional functions here

# Create an AVL tree to store the values of the functions
tree = FastAVLTree()

# Compute the values of the functions for each x value in the range and insert them into the tree
for x in np.linspace(x_min, x_max, num=int((x_max - x_min) / tol)):
    tree[x] = [f(x) for f in input_functions]


# Find the range of x values that satisfy f1(x) = 1, f2(x) = 0.5, and f3(x) = -0.5
x_range = find_all_x_ranges(y1,y2,y3)
print(x_range)

너무 잘 찾는 것 같다.


[[2.000200020002, 4.000400040004], [6.0006000600060005, 8.000800080008]]

 

엄청나게 고민하면서 할 일을 금방 끝낸 기분이다

실제로 써보니 대단하면서도 내 일자리에 위협을 줄 수 있다는 생각에 걱정이 된다 ㅠㅠㅠㅠ

 

참고

https://www.codingninjas.com/codestudio/library/introduction-to-avl-trees

 

Introduction To AVL Trees

This article discusses the introduction to AVL trees, their advantages, representation, properties and various operations performed on them.

www.codingninjas.com

https://code-lab1.tistory.com/61

 

[자료구조] AVL트리란? AVL트리 쉽게 이해하기, AVL트리 시뮬레이터

AVL 트리란? 예전에 이진탐색트리에 대해 알아본적이 있다. [자료구조] 이진탐색트리(Binary Search Tree)의 개념, 이해 | C언어 이진탐색트리 구현 이진탐색트리(Binary Search Tree)이란? 이진탐색트리란

code-lab1.tistory.com

 

728x90