Skip to content
Permalink
main
Switch branches/tags

Name already in use

A tag already exists with the provided branch name. Many Git commands accept both tag and branch names, so creating this branch may cause unexpected behavior. Are you sure you want to create this branch?
Go to file
 
 
Cannot retrieve contributors at this time
""" Basic BST code for inserting (i.e. building) and printing a tree
Your ***second standard viva task*** (of 5) will be to implement a find method into
the class BinaryTree from pseudocode. See the lab task sheet for Week 5.
Your ***first advanced viva task*** (of 3) will be to implement a remove (delete) method
into the class BinaryTree from partial pseudocode. See the lab task sheet for Week 5 (available in Week 5).
There will be some ***introductory challenges*** in Week 4, with solutions released in Week 5.
It is STRONGLY RECOMMENDED you attempt these!
Since the given code is in python it is strongly suggested you stay with python; but
if you want to reimplement as C++ this is also OK (see the Week 5 lab sheet guidance).
"""
import math
""" Node class
"""
class Node:
    """A single tree node: one stored value plus links to two children."""

    def __init__(self, data=None):
        """Create a node holding *data* with no children attached yet."""
        self.data = data
        self.left = None   # left child node, or None when absent
        self.right = None  # right child node, or None when absent

    def __repr__(self):
        """Return a short debugging representation showing the stored value."""
        return "Node(%r)" % (self.data,)
""" BST class with insert and display methods. display pretty prints the tree
"""
class BinaryTree:
    """Binary search tree with insert, remove and pretty-printing support.

    Values smaller than a node's value are stored in its left subtree and
    larger values in its right subtree; duplicates are rejected by insert.
    """

    def __init__(self):
        """Create an empty tree."""
        self.root = None

    ##############################
    ##    def find_i(self, data): #pass in values defined elsewhere
    ##        print("iteratively searching for: ", data) #tell the user what the function is doing
    ##        cur_node = self.root #set the current node to be equal to the root value
    ##        while cur_node != None: #loop while there is a node to check
    ##            if cur_node.data == data: #if the current node is equal to the value being searched for
    ##                print("Value found")
    ##                return True #target value is found, return True
    ##            elif cur_node.data > data: #if the current node is greater than the target value
    ##                cur_node = cur_node.left #set the current node as the left node
    ##                print("Current node is greater than target, moving to the left child")
    ##            else: #if the current node is smaller than the target
    ##                cur_node = cur_node.right #set the current node as the right node
    ##                print("Current node is less than the target, moving to the right child")
    ##        return False #target value isnt found, return false, this whole section is looped until the value is found
    ##
    #######
    ##
    ##    def find_r(self, data): #pass in values defined elsewhere
    ##        print("recursively searching for: ", data) #tell the user what the function is doing
    ##        return self._find_r(data, self.root) #pass data to the main search function
    ##
    ##    def _find_r(self, data, cur_node):
    ##        if cur_node.data == data: #if the current node is equal to the value being searched for
    ##            print("Value found")
    ##            return True #target value is found, return True
    ##        elif cur_node.data > data: #if the current node is greater than the target value
    ##            print("Current node is greater than target, moving to the left child")
    ##            return self._find_r(data,cur_node.left) #recursively call the function
    ##        else: #if the current node is smaller than target
    ##            print("Current node is less than the target, moving to the right child")
    ##            return self._find_r(data,cur_node.right) #recursively call the function
    ##        return False #target value isnt found
    ##############################

    def insert(self, data):
        """Insert *data* into the tree; a duplicate value is reported and ignored."""
        if self.root is None:
            self.root = Node(data)
        else:
            self._insert(data, self.root)

    def _insert(self, data, cur_node):
        """Recursively place *data* in the subtree rooted at *cur_node*."""
        if data < cur_node.data:
            if cur_node.left is None:
                cur_node.left = Node(data)
            else:
                self._insert(data, cur_node.left)
        elif data > cur_node.data:
            if cur_node.right is None:
                cur_node.right = Node(data)
            else:
                self._insert(data, cur_node.right)
        else:
            print("Value already present in tree")

    #v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v v
    def remove(self, data):
        """Remove *data* from the tree.

        Returns True when a node was removed and False when the tree is
        empty or does not contain *data* (the tree is left unchanged).
        """
        print("removing: ", data)
        if not self._contains(data):
            return False
        # _remove returns the (possibly new) root of the subtree it was given,
        # so the result must be stored back or the deletion is lost.
        self.root = self._remove(data, self.root)
        return True

    def _contains(self, data):
        """Return True if *data* is stored somewhere in the tree."""
        cur_node = self.root
        while cur_node is not None:
            if data == cur_node.data:
                return True
            cur_node = cur_node.left if data < cur_node.data else cur_node.right
        return False

    def _remove(self, data, cur_node):
        """Delete *data* from the subtree rooted at *cur_node*.

        Returns the node that should now occupy *cur_node*'s position, so the
        caller can relink its child pointer (or the tree root) to it.
        """
        if cur_node is None:  # CASE 1: target not found in this subtree
            return None
        if data < cur_node.data:
            print("target is smaller than current node")
            cur_node.left = self._remove(data, cur_node.left)
            return cur_node
        if data > cur_node.data:
            print("target is greater than current node")
            cur_node.right = self._remove(data, cur_node.right)
            return cur_node

        # From here on cur_node is the node holding the target value.
        if cur_node.left is None and cur_node.right is None:  # CASE 2: no children
            print("leaf node deleted")
            return None
        if cur_node.right is None:  # CASE 3: left child only
            print("left node deleted")
            return cur_node.left
        if cur_node.left is None:  # CASE 4: right child only
            print("right node deleted")
            return cur_node.right

        # CASE 5: two children. Overwrite the value with the in-order
        # predecessor (largest value in the left subtree), then delete that
        # predecessor node, which has at most one child.
        print("parent node created")
        temp_node = cur_node.left
        while temp_node.right is not None:
            temp_node = temp_node.right
        cur_node.data = temp_node.data
        cur_node.left = self._remove(temp_node.data, cur_node.left)
        return cur_node
    #^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^ ^

    def display(self, cur_node):
        """Pretty-print the subtree rooted at *cur_node* to standard output."""
        if cur_node is None:
            # _display dereferences the node immediately, so an empty tree
            # has to be handled here.
            print("<empty tree>")
            return
        lines, _, _, _ = self._display(cur_node)
        for line in lines:
            print(line)

    def _display(self, cur_node):
        """Build the text picture of the subtree rooted at *cur_node*.

        Returns a tuple (lines, width, height, middle): the list of text rows,
        the width of each row, the number of rows, and the column at which
        the parent's connector should attach.
        """
        s = str(cur_node.data)
        u = len(s)

        # No children: the picture is just the value itself.
        if cur_node.right is None and cur_node.left is None:
            return [s], u, 1, u // 2

        # Only a left child: draw it below-left, value goes on the right.
        if cur_node.right is None:
            lines, n, p, x = self._display(cur_node.left)
            first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s
            second_line = x * ' ' + '/' + (n - x - 1 + u) * ' '
            shifted_lines = [line + u * ' ' for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, n + u // 2

        # Only a right child: value goes on the left, child below-right.
        if cur_node.left is None:
            lines, n, p, x = self._display(cur_node.right)
            first_line = s + x * '_' + (n - x) * ' '
            second_line = (u + x) * ' ' + '\\' + (n - x - 1) * ' '
            shifted_lines = [u * ' ' + line for line in lines]
            return [first_line, second_line] + shifted_lines, n + u, p + 2, u // 2

        # Two children: render both sides and join them row by row.
        left, n, p, x = self._display(cur_node.left)
        right, m, q, y = self._display(cur_node.right)
        first_line = (x + 1) * ' ' + (n - x - 1) * '_' + s + y * '_' + (m - y) * ' '
        second_line = x * ' ' + '/' + (n - x - 1 + u + y) * ' ' + '\\' + (m - y - 1) * ' '
        # Pad the shorter picture with blank rows so zip() keeps every row.
        if p < q:
            left += [n * ' '] * (q - p)
        elif q < p:
            right += [m * ' '] * (p - q)
        zipped_lines = zip(left, right)
        lines = [first_line, second_line] + [a + u * ' ' + b for a, b in zipped_lines]
        return lines, n + m + u, max(p, q) + 2, n + u // 2
#example calls, which construct and display the tree
bst = BinaryTree()
# Build the sample tree by inserting the values in this fixed order.
for value in (4, 2, 6, 1, 3, 5, 7, 100, 33, 200):
    bst.insert(value)
# v v v v v v v
bst.remove(6)
# ^ ^ ^ ^ ^ ^ ^
#bst.insert(6)
# Print the resulting tree starting from its root.
bst.display(bst.root)