본문 바로가기

프로그래밍

[42서울] ft_containers[3] - 트리 구현

서론

원래 rbtree를 이용해서 구현하고 싶었는데, 제대로 이해하지 못한 상태로 대충 구현하고 나서 테스트를 해보니 map의 테스트에서 일부분 문제가 생겼다.(어느 부분인지는 map에서 설명)

 

내가 그걸 못고치는걸 보면서 이건 내가 이해한게 아니라고 생각해서 다른 bst를 찾아봤는데, splay tree라고 하는 간단하게 밸런싱을 하는 트리를 발견해서 그걸로 구현했더니 역시 문제가 생겼다. splay tree는 선형 체인에서 O(n)이어서 30만 개 선형 체인 테스트를 std::map보다 20배 이상 느려져서 통과하지 못했다. 

 

그래서 그냥 힙스터 기질은 접고 다른 사람들처럼 avl tree로 구현하게 되었다.

 

https://www.cs.usfca.edu/~galles/visualization/AVLtree.html

 

AVL Tree Visualzation

 

www.cs.usfca.edu

이걸 보면 직관적으로 이해할수있다. 

 

avl tree란 제작자의 이름을 따서 만들어진 균형 이진 탐색 트리의 일종이다. 다양한 데이터가 입력되었을 때 트리가 한쪽으로 치우치는 것을 막기 위해 매번의 삽입, 삭제마다 왼쪽 서브트리와 오른쪽 서브트리의 높이가 1 이상 차이 나지 않도록 점검해서 트리를 수정하는 방식이다.

 

avl 트리는 bst 와 모든 부분에서 동일하지만, 회전이 필요하고, 삽입과 삭제마다 높이를 확인하는 동작이 추가된다. 당연히 이런 동작이 많아지니까 cpu 오버헤드도 늘어나지만, 최악의 경우에도 O(log n)을 유지하기 때문에 균형 이진 탐색 트리의 대표 격으로 불린다.

 

삽입과 삭제 후에 높이 차이가 2 이상 나는 노드가 있을경우 그 노드를 중심으로 회전을 진행하게 되는데 4가지 경우(LL, RR, LR, RL)로 분류해서 진행한다.

 

구현

template< class Key, class T, class Compare, class Allocator >
class avl_node
{
	public:
		typedef avl_node<Key, T, Compare, Allocator>	node;
		typedef node*									node_ptr;
		typedef Key										key_type;
		typedef T										value_type;
		typedef ft::pair<Key, T>						pair_type;
		typedef size_t									size_type;
		typedef ssize_t									balance_type;
		typedef Allocator								allocator_type;

	private:
		typedef typename allocator_type::template rebind<node>::other node_allocator;
		// 내부적으로 avl_node<Key, T, Compare, Allocator> 만 전담하는 allocator

		pair_type	_data;
		size_type	_height;
		node_ptr 	left;
		node_ptr 	right;
		Compare		key_compare;
}

이런 변수들을 사용하게 되는데, bst로 구현을 했어도 변수는 여기서 크게 달라지지 않았을 것이다. 데이터, 왼쪽노드, 오른쪽노드, compare는 모든 tree에서 당연히 들어가는 부분이고, 여기에 height가 추가되면 avl tree가 된다.

 

그 외에 find, upper_bound, lower_bound, next(successor), prev(predecessor) 같은 함수들은 동일한데, insert / delete에서는 차이가 있다.

 

insert

insert도 bst처럼 해당 값이 들어가야하는 위치를 찾아서 입력하는 건 동일하지만, 삽입 이후에 균형을 계산해서 그 균형에 맞게 추가로 동작을 해줘야 한다.

 

우선 bst의 삽입을 보자(from. chatGPT)

반복문을 이용해서 해당 노드가 들어갈 자리를 찾고 삽입을 한다.

 

균형을 맞추기 위한 동작이 아예 없기 때문에, 1, 2, 3, 4, 5, 이런 순서로 계속 들어오면 선형 체인(linear chain)이 되어서 이후에 탐색, 삽입, 삭제의 시간복잡도에 문제가 생긴다.

 

그래서 균형을 맞추기 위한 변수와 동작이 추가된 avl tree는 다음과 같다.

static node_ptr insert(node_ptr node, pair_type data, node_allocator& alloc)
{
	// 노드가 없을경우 생성
	if (node == NULL)
		return(new_node(data, alloc));
	// data 의 key 가 node 의 키보다 작으면 왼쪽으로 이동
	if (data.first < node->key()) 
		node->left = insert(node->left, data, alloc);
	//  data 의 key 가 node 의 키보다 크면 오른쪽으로 이동
	else if (data.first > node->key()) 
		node->right = insert(node->right, data, alloc);
	// 키가 중복될수없기때문에 바로 반환
	else
		return node;

	node->_height = 1 + std::max(height(node->left), height(node->right));

	// 균형을 계산해서 균형이 깨진 경우에 LL RR LR RL 동작을 수행
	int bal = balance(node);

	// 루트의 왼쪽 자식의 왼쪽 서브트리에 노드를 삽입한 경우(LL), 오른쪽으로 회전
	if (bal > 1 && data.first < node->left->key())
		return right_rotate(node);
	// 루트의 오른쪽 자식의 오른쪽 서브트리에 노드를 삽입한 경우(RR), 왼쪽으로 회전
	if (bal < -1 && data.first > node->right->key()) 
		return left_rotate(node);
	// 루트의 왼쪽 자식의 오른쪽 서브트리에 노드를 삽입한 경우(LR), 왼쪽 자식을 왼쪽으로 회전한 후, 루트를 오른쪽으로 회전
	if (bal > 1 && data.first > node->left->key()) 
	{
		node->left = left_rotate(node->left);
		return right_rotate(node);
	}
	// 루트의 오른쪽 자식의 왼쪽 서브트리에 노드를 삽입한 경우(RL), 오른쪽 자식을 오른쪽으로 회전한 후, 루트를 왼쪽으로 회전
	if (bal < -1 && data.first < node->right->key()) 
	{
		node->right = right_rotate(node->right);
		return left_rotate(node);
	}

	return node;
}

반복문이 재귀로 바뀌어서 재귀를 통해 해당 key의 위치를 찾아서 삽입한 다음, 재귀가 풀리면서 해당 위치를 찾기 위해 지나간 모든 경로에서 균형을 확인한 다음 적절하게 회전을 해주도록 구현했다.

 

delete

delete 도 insert와 비슷한 느낌으로, 해당 노드를 찾아서 삭제한 다음 삭제한 노드부터 재귀가 풀리면서 균형을 확인하고 적절하게 회전을 해준다.

// 수정된 서브트리의 루트를 반환
static node_ptr del_node(node_ptr root, key_type key, node_allocator& alloc)
{
	if (root == NULL)
		return root;

	// 해당 key를 찾기위한 재귀
	if (key < root->key())
		root->left = del_node(root->left, key, alloc);
	else if (key > root->key()) 
		root->right = del_node(root->right, key, alloc);
	// 해당 key 를 찾았을때 / 혹은 찾고자 하는 key가 없을때
	else 
	{
		if ((root->left == NULL) || (root->right == NULL)) 
		{
			node_ptr temp = root->left ? root->left : root->right;

			if (temp == NULL) 
			{
				temp = root;
				root = NULL;
			}
			else
				*root = *temp;
			alloc.destroy(temp);
			alloc.deallocate(temp, 1);
		}
		else // 노드에 자식이 둘일때 오른쪽에서 가장 작은 값을 찾아서 그것과 현재 key를 변경하고, 그 다음에 그 노드를 삭제한다.
		{
			node_ptr temp = min(root->right);
			root->_data = temp->_data;

			root->right = del_node(root->right, temp->key(), alloc);
		}
	}

	if (root == NULL) 
		return root;

	root->_height =	std::max(height(root->left), height(root->right)) + 1;

	int bal = balance(root);

	// Left Left
	if (bal > 1 && balance(root->left) >= 0) 
		return right_rotate(root);
	// Left Right
	if (bal > 1 && balance(root->left) < 0)
	{
		root->left = left_rotate(root->left);
		return right_rotate(root);
	}
	// Right Right
	if (bal < -1 && balance(root->right) <= 0)
		return left_rotate(root);

	// Right Left
	if (bal < -1 && balance(root->right) > 0) 
	{
		root->right = right_rotate(root->right);
		return left_rotate(root);
	}

	return root;
}