Note: I am experimenting with exporting Jupyter notebooks into a WordPress-ready format. This notebook refers specifically to the Nature Conservancy Kaggle, for classifying fish species, based only on photographs.
I’ve been messing around with a few things during my time off for the Holidays:
So here’s a quick combination of these things, in the form of a simple guide to using Keras on the Nature Conservancy image recognition Kaggle.
Hopefully it serves as an easy introduction to get up and running with neural networks for this competition.
For a while now, the Computer Science department at my University has offered a class for non-CS students called “Data Witchcraft”. The idea, I suppose, is that when you don’t understand how a technology works, it’s essentially “magic”. (And when it has to do with computers and data, it’s dark magic at that.)
But even as someone who writes programs all day long, I find there are often tools, algorithms, or ideas we use that we don’t really understand. We just take them at face value because, well, they seem to work, we’re busy, and learning doesn’t seem necessary when the problem is already sufficiently solved.
One of the more prevalent algorithms of this sort is Gradient Descent (GD). The algorithm is both conceptually simple (everyone likes to show rudimentary sketches of a blindfolded stick figure walking down a mountain) and mathematically rigorous (next to those simple sketches, we show equations with partial derivatives across n-dimensional vectors mapped to an arbitrarily sized higher-dimensional space).
So most often, after learning about GD, you are sent off into the wild to use it, without ever having programmed it from scratch. From this view, Gradient Descent is a sort of incantation we Supreme Hacker Mage Lords can use to solve complex optimization problems whenever our data is in the right format, and we want a quick fix. (Kind of like Neural Networks and “Deep Learning”…) This is practical for most people who just need to get things done. But it’s also unsatisfying for others (myself included).
However, GD can also be implemented in just a few lines of code (even though it won’t be as highly optimized as an industrial-strength version).
That’s why I’m sharing some implementations of both Univariate and generalized Multivariate Gradient Descent written in simple and annotated Python.
Anyone curious about a working implementation (and with some test data in hand) can try this out to experiment. The code snippets below have print statements built in so you can see how your model changes every iteration.
To download and run the full repo, clone it from here: https://github.com/adpoe/Gradient_Descent_From_Scratch
But the actual algorithms are also extracted below, for ease of reading.
Requires NumPy.
Also requires data to be in this format: [(x1,y1), (x2,y2) … (xn,yn)], where y is the actual value.
def gradient_descent(training_examples, alpha=0.01):
    """
    Apply gradient descent on the training examples to learn a line that fits through the examples
    :param training_examples: list of all examples in (x, y) format
    :param alpha: learning rate
    :return: the learned weights (w0, w1)
    """
    # initialize w0 and w1 to some small value, here just using 0 for simplicity
    w0 = 0
    w1 = 0
    # repeat until "convergence", meaning that w0 and w1 aren't changing very much
    # --> need to define what 'not very much' means, and that may depend on problem domain
    convergence = False
    while not convergence:
        # initialize temporary variables, and set them to 0
        delta_w0 = 0
        delta_w1 = 0
        for pair in training_examples:
            # grab our data points from the example
            x_i = pair[0]
            y_i = pair[1]
            # find the prediction error once, and accumulate it into both deltas
            error = prediction_error(w0, w1, x_i, y_i)
            delta_w0 += error
            delta_w1 += error * x_i
        # store previous weighting values
        prev_w0 = w0
        prev_w1 = w1
        # get new weighting values
        w0 = w0 + alpha*delta_w0
        w1 = w1 + alpha*delta_w1
        # decay the learning rate; with the default alpha=0.01 this linear decay
        # reaches roughly zero after 10 iterations, so the loop typically exits soon after
        alpha -= 0.001
        # every iteration, print out the current model
        # 1. --> (w0 + w1x1 + w2x2 + ... + wnxn)
        print("Current model is: ("+str(w0)+" + "+str(w1)+"x1)")
        # 2. --> averaged squared error over training set, using the current line
        summed_error = sum_of_squared_error_over_entire_dataset(w0, w1, training_examples)
        avg_error = summed_error/float(len(training_examples))
        print("Average Squared Error="+str(avg_error))
        # check if we have converged
        if abs(prev_w0 - w0) < 0.00001 and abs(prev_w1 - w1) < 0.00001:
            convergence = True
    # after convergence, print out the parameters of the trained model (w0, ... wn)
    print("Parameters of trained model are: w0="+str(w0)+", w1="+str(w1))
    return w0, w1

############################
##### TRAINING HELPERS #####
############################
def model_prediction(w0, w1, x_i):
    return w0 + (w1 * x_i)

def prediction_error(w0, w1, x_i, y_i):
    # basically, we just take the true value (y_i)
    # and we subtract the predicted value from it
    # this gives us the error (residual) for a single example
    return y_i - model_prediction(w0, w1, x_i)

def sum_of_squared_error_over_entire_dataset(w0, w1, training_examples):
    # find the squared error over the whole training set
    total = 0
    for pair in training_examples:
        x_i = pair[0]
        y_i = pair[1]
        total += prediction_error(w0, w1, x_i, y_i) ** 2
    return total
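For anyone wondering where the two accumulators (delta_w0 and delta_w1) come from, here is a short derivation. It assumes the cost being minimized is half the summed squared error; the factor of 1/2 is an assumption made for convenience, since the code never writes the cost function down explicitly.

```latex
\begin{align*}
J(w_0, w_1) &= \frac{1}{2} \sum_{i=1}^{n} \bigl(y_i - (w_0 + w_1 x_i)\bigr)^2 \\
\frac{\partial J}{\partial w_0} &= -\sum_{i=1}^{n} \bigl(y_i - (w_0 + w_1 x_i)\bigr) \\
\frac{\partial J}{\partial w_1} &= -\sum_{i=1}^{n} \bigl(y_i - (w_0 + w_1 x_i)\bigr)\, x_i \\
w_0 &\leftarrow w_0 - \alpha \frac{\partial J}{\partial w_0}
     = w_0 + \alpha \sum_{i=1}^{n} \bigl(y_i - (w_0 + w_1 x_i)\bigr) \\
w_1 &\leftarrow w_1 - \alpha \frac{\partial J}{\partial w_1}
     = w_1 + \alpha \sum_{i=1}^{n} \bigl(y_i - (w_0 + w_1 x_i)\bigr)\, x_i
\end{align*}
```

The two sums on the right are exactly what the inner loop accumulates, which is why the weights are updated with a plus sign even though we are descending the gradient.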
Requires NumPy, same as above.
Also requires data to be in this format: [((x1,..xn), y), ((x1,..xn), y) … ((x1,..xn), y)], where y is the actual value. Essentially, you can have as many x-variables as you’d like, as long as they are grouped together as the first element of each tuple and the y-value is the last element.
import numpy

def multivariate_gradient_descent(training_examples, alpha=0.01):
    """
    Apply gradient descent on the training examples to learn a linear model that fits the examples
    :param training_examples: list of all examples in (xs, y) format, where xs is a sequence of x-values
    :param alpha: learning rate
    :return: the learned weights (W_0, W)
    """
    # initialize the weight vector, one weight per x-variable
    W = [0 for index in range(0, len(training_examples[0][0]))]
    # W_0 is a constant
    W_0 = 0
    # repeat until "convergence", meaning that the weights aren't changing very much
    # --> need to define what 'not very much' means, and that may depend on problem domain
    convergence = False
    while not convergence:
        # initialize temporary variables, and set them to 0
        deltaW_0 = 0
        deltaW_n = [0 for x in range(0, len(training_examples[0][0]))]
        for pair in training_examples:
            # grab our data points from the example
            x_i = pair[0]
            y_i = pair[1]
            # find the prediction error, then accumulate error * x_i element-wise
            error = multivariate_prediction_error(W_0, y_i, W, x_i)
            deltaW_0 += error
            deltaW_n = numpy.add(deltaW_n, numpy.multiply(error, x_i))
        # store previous weighting values
        prev_w0 = W_0
        prev_Wn = W
        # get new weighting values
        W_0 = W_0 + alpha*deltaW_0
        W = numpy.add(W, numpy.multiply(alpha, deltaW_n))
        # decay the learning rate; with the default alpha=0.01 this linear decay
        # reaches roughly zero after 10 iterations, so the loop typically exits soon after
        alpha -= 0.001
        # every iteration, print out the current model
        # 1. --> (w0 + w1x1 + w2x2 + ... + wnxn)
        variables = [(str(W[i]) + "*x" + str(i+1) + " + ") for i in range(0, len(W))]
        var_string = ''.join(variables)
        var_string = var_string[:-3]
        print("Current model is: " + str(W_0) + " + " + var_string)
        # 2. --> averaged squared error over training set, using the current model
        summed_error = multivariate_sum_of_squared_error_over_entire_dataset(W_0, W, training_examples)
        avg_error = summed_error/float(len(training_examples))
        print("Average Squared Error=" + str(avg_error))
        print("")
        # check if we have converged: every weight must have changed by less than the tolerance
        if abs(prev_w0 - W_0) < 0.00001 and numpy.all(numpy.abs(numpy.subtract(prev_Wn, W)) < 0.00001):
            convergence = True
    # after convergence, print out the parameters of the trained model (w0, ... wn)
    variables = [("w" + str(i+1) + "=" + str(W[i]) + ", ") for i in range(0, len(W))]
    var_string = ''.join(variables)
    var_string = var_string[:-2]
    print("RESULTS: ")
    print("\tParameters of trained model are: w0=" + str(W_0) + ", " + var_string)
    return W_0, W

################################
##### MULTIVARIATE HELPERS #####
################################
# generalize these to just take a w0, a vector of weights, and a vector of x-values
def multivariate_model_prediction(w0, weights, xs):
    return w0 + numpy.dot(weights, xs)

# again, this needs to take just a w0, a vector of weights, and a vector of x-values
def multivariate_prediction_error(w0, y_i, weights, xs):
    # basically, we just take the true value (y_i)
    # and we subtract the predicted value from it
    # this gives us the error (residual) for a single example
    return y_i - multivariate_model_prediction(w0, weights, xs)

# same idea as the univariate version, but uses the generalized functions above
def multivariate_sum_of_squared_error_over_entire_dataset(w0, weights, training_examples):
    # find the squared error over the whole training set
    total = 0
    for pair in training_examples:
        x_i = pair[0]
        y_i = pair[1]
        # cast back to values in range [1 --> 20]
        # (specific to my data set, where the targets were scaled by 1/20)
        prediction = multivariate_model_prediction(w0, weights, x_i) / (1/20.0)
        actual = y_i / (1/20.0)
        error = abs(actual - prediction)
        total += error ** 2
    return total
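If you don’t have data on hand, here is a minimal, vectorized sketch you can run right away. It is not the code from the repo: the function name fit_linear, the synthetic data, the fixed learning rate, and the use of the mean (rather than the sum) of the errors are all assumptions made to keep the example short and stable.

```python
import numpy as np

def fit_linear(X, y, alpha=0.1, tol=1e-10, max_iters=10000):
    """Batch gradient descent on mean squared error; returns (w0, w).

    Hypothetical helper: alpha, tol and max_iters are illustrative defaults.
    """
    n_samples, n_features = X.shape
    w0, w = 0.0, np.zeros(n_features)
    for _ in range(max_iters):
        errors = y - (w0 + X @ w)                    # y_i minus prediction, per example
        step_w0 = alpha * errors.mean()              # update for the constant term
        step_w = alpha * (X.T @ errors) / n_samples  # update for each weight
        w0, w = w0 + step_w0, w + step_w
        # stop once no parameter moved by more than the tolerance
        if abs(step_w0) < tol and np.all(np.abs(step_w) < tol):
            break
    return w0, w

# synthetic, noise-free data generated from y = 2 + 3*x1 - 1*x2
rng = np.random.default_rng(0)
X = rng.normal(size=(100, 2))
y = 2.0 + X @ np.array([3.0, -1.0])

w0, w = fit_linear(X, y)
print("w0 =", w0, "weights =", w)
```

Because the data is noise-free, the learned parameters should land very close to the values used to generate it, which makes this a handy sanity check before moving to real data.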
My data set is included in the full repo. But feel free to try it on your own, if you’re experimenting with this. And enjoy.
In which I peel back the curtain and outline the inner workings of a particularly insidious artificial intelligence, whose sole purpose in life is to systematically learn the optimal strategy for a terrifyingly addictive video game, known only to the internet as: Flappy Bird… and in which I also provide code to program a similar AI of your own.
More pointedly, this short post outlines a practical way to get started using a Reinforcement Learning technique called Q-Learning, as applied to a Python Flappy Bird clone, programmed by @TimoWilken.
>> Grab the code base: https://github.com/adpoe/Flappy-AI <<
So you want to beat Flappy Bird, but after a while it gets tedious. I agree. Instead, why don’t we program an AI to do it for us? A genius plan, but where do we start?
First, we need a Flappy Bird game to hack up. The candidate that I suggest is a Python implementation created by Timo Wilken and available for download directly at: https://github.com/TimoWilken/flappy-bird-pygame. This Flappy Bird version is implemented using the PyGame library, which is a dependency going forward.
Here are instructions for PyGame installation. If you get this running, the hard work is done. (apt-get or homebrew are highly recommended.)
The first challenge we’ll have in implementing the framework for a Flappy AI is determining exactly how the game works in its original state.
By using the debugger and stepping through the game’s code during some trial runs, I was able to figure out where key decisions were made, how data flowed into the game, and exactly where I would need to position my AI-agent.
At its basic level, I created an “Agent” class, and passed that class into the running game code. Then, at each loop of the game, I examined the variables available to me, and then passed a ‘MOUSEBUTTONUP’ command to the PyGame event queue whenever the AI decided to jump. Otherwise, I did nothing.
From there, the next step was determining a way to model the problem. I decided to follow the basic guidelines outlined by Sarvagya Vaish, here.
First, I discretized the space in which the bird sat, relative to the next pipe. I was able to get pipe data by accessing the pipe object in the original game code. Similarly, I was able to get bird data by accessing the bird object.
From there, I could determine the location of the bird and the pipes relative to each other. I discretized this space as a grid of 25 cells (five height categories by five distance categories), with the following parameters:
# first value in state tuple
height_category = 0
dist_to_pipe_bottom = pipe_bottom - bird.y
if dist_to_pipe_bottom < 8:  # very close
    height_category = 0
elif dist_to_pipe_bottom < 20:  # close
    height_category = 1
elif dist_to_pipe_bottom < 125:  # mid
    height_category = 2
elif dist_to_pipe_bottom < 250:  # far
    height_category = 3
else:
    height_category = 4

# second value in state tuple
dist_category = 0
dist_to_pipe_horz = pp.x - bird.x
if dist_to_pipe_horz < 8:  # very close
    dist_category = 0
elif dist_to_pipe_horz < 20:  # close
    dist_category = 1
elif dist_to_pipe_horz < 125:  # mid
    dist_category = 2
elif dist_to_pipe_horz < 250:  # far
    dist_category = 3
else:
    dist_category = 4
Using this methodology, I created a state tuple that looked like this:
(height_category={0,1,2,3,4}, dist_category={0,1,2,3,4} , collision=True/False)
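The two if/elif chains and the state tuple above can be condensed into a small helper. This is only a sketch: the names THRESHOLDS, bucket, and make_state are hypothetical, and plain numbers stand in for the game’s bird and pipe objects.

```python
from bisect import bisect_right

# Upper bounds (in pixels) for the "very close", "close", "mid" and "far"
# buckets, copied from the if/elif chains; anything beyond falls in bucket 4.
THRESHOLDS = [8, 20, 125, 250]

def bucket(distance):
    """Map a pixel distance to a category 0..4 (hypothetical helper)."""
    return bisect_right(THRESHOLDS, distance)

def make_state(bird_x, bird_y, pipe_x, pipe_bottom, collision):
    """Build the (height_category, dist_category, collision) state tuple."""
    height_category = bucket(pipe_bottom - bird_y)
    dist_category = bucket(pipe_x - bird_x)
    return (height_category, dist_category, collision)

# 10 px below the pipe bottom edge, 110 px away horizontally, no collision
print(make_state(bird_x=50, bird_y=200, pipe_x=160, pipe_bottom=210, collision=False))
# prints (1, 2, False)
```

Keeping the thresholds in one list makes it easy to experiment with tighter or looser grids without rewriting the branching logic.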
Then, each iteration of the game loop, I was able to determine the bird’s relative position, and whether it had made a collision with the pipes or not.
If there was no collision, I issued a reward of +1.
If there was a collision, I issued a reward of -1000.
I tried many different state representations here, but mostly it was a matter of determining an optimal number of grid spaces and the right parameters for those spaces.
Initially, I started with a 9×9 grid, but moved to 16×16 because I got to a point in 9×9 where I just couldn’t make any more learning progress.
Very generally, we want to have a tighter grid around the pipes, as this is where most collisions happen. And we want a looser grid as we move outwards. This seemed to give me the best results, as we need different strategies at different locations on the grid.
Our next task is implementing an exploration approach. This is necessary because if we don’t randomly explore the state sometimes, there might be optimal strategies that we are never able to find, simply because we will never be in those states!
Because we have only two choices at any given state (JUMP—or—STAY), implementing exploration was relatively simple.
I started out with a high exploration factor (I used 1/(time_value+1)), and then I generated a random number in the range [0,1). If the random number was less than the exploration factor, then I explored.
Over time the exploration factor got lower, and therefore the AI explored less frequently.
Exploration essentially consisted of flipping a fair coin (generating a Boolean value randomly).
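Here is a minimal sketch of that exploration scheme. It assumes the exploration factor is 1/(t+1), where t counts game-loop iterations, and that the agent already has some greedy choice available; the function names and the rng parameter are hypothetical.

```python
import random

def exploration_factor(t):
    """Starts at 1.0 and shrinks toward 0 as the time step t grows."""
    return 1.0 / (t + 1)

def choose_action(t, greedy_action, rng=random):
    """Return True for JUMP, False for STAY.

    With probability exploration_factor(t), ignore the greedy choice and
    flip a fair coin instead; otherwise, take the greedy action.
    """
    if rng.random() < exploration_factor(t):
        return rng.random() < 0.5   # explore: fair coin flip
    return greedy_action            # exploit: use what we've learned so far

# early on almost every move is a coin flip; later, almost none are
for t in (0, 9, 999):
    print(t, exploration_factor(t), choose_action(t, greedy_action=False))
```

Passing in rng makes the function easy to test with a seeded random.Random instance.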
The main problem I encountered with this method is that the exploration factor was very high at the beginning, and sometimes choices were made that were not representative of actual situations that the bird would encounter in ‘true’ gameplay.
BUT, because these decisions were made earlier, they were weighted more heavily in the overall Q-Learning algorithm.
This isn’t ideal, but exploration is necessary, and overall the algorithm works well. So it wasn’t a large problem, overall.
Very simply, “Learning Rates” dictate how much we weigh new information about some state over old information. Learning rates can be any value in the range [0,1], with 0 meaning we never update values (bad), and 1 meaning that we only EVER care about what happened the last time we were in that state (short-sighted).
The first learning rate I tried was alpha=1/(time+1). However, this gave very poor results in practice.
This is because time is NOT the most important factor in determining a strategy from any given state. Rather, it is how many times we’ve been to that state.
The problem is that we make extremely poor choices at the beginning of the game (because we simply don’t know any better). But with alpha=1/(time+1), the results of these poor choices are weighted the most highly.
Once I changed the learning factor to alpha=1/N(s,a), I immediately saw dramatically better results. (That is, where N(s,a) tracks how many times we’ve been in a given state and performed the same action.)
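To make the alpha=1/N(s,a) idea concrete, here is a minimal sketch of a Q-table update that uses it. The class name, the discount factor gamma=0.9, and the terminal flag are assumptions for illustration; only the +1 / -1000 rewards and the visit-count learning rate come from the description above.

```python
from collections import defaultdict

ACTIONS = (False, True)   # False = STAY, True = JUMP

class QTable:
    """Hypothetical Q-table with a per-(state, action) learning rate."""

    def __init__(self, gamma=0.9):
        self.q = defaultdict(float)      # Q(s, a), defaults to 0.0
        self.visits = defaultdict(int)   # N(s, a)
        self.gamma = gamma               # discount factor (assumed value)

    def update(self, state, action, reward, next_state, terminal=False):
        key = (state, action)
        self.visits[key] += 1
        alpha = 1.0 / self.visits[key]   # alpha = 1/N(s, a)
        # value of the best action from the next state (0 if the game ended)
        best_next = 0.0 if terminal else max(self.q[(next_state, a)] for a in ACTIONS)
        target = reward + self.gamma * best_next
        self.q[key] += alpha * (target - self.q[key])
        return self.q[key]

table = QTable()
s, s_next = (2, 2, False), (1, 1, True)
print(table.update(s, True, -1000, s_next, terminal=True))   # prints -1000.0
print(table.update(s, True, 1, s_next))                      # prints -499.5
```

With this learning rate, each Q-value is a running average of the targets seen for that (state, action) pair, so a single early disaster gets diluted as the pair is visited more often.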
My final, “Smart” bird is the result of about 4 hours of training.
I don’t actually think there would be a way to make the training more efficient, aside from speeding up the gameplay in some way.
Overall, I think the results I received were reasonable for the time I invested.
Given more time, I would probably discretize the space even more finely (maybe a 36×36 grid) – so that I could find even more optimal strategies from a more fine-tuned set of positions in the game-space.
To use my smart bird, simply take the following steps:
Probably more instructive than using my trained bird though, is to simply start training a new bird from scratch. You will see the agony and the ecstasy as he does a terrible number of dumb things, slowly learning how to beat the game.
It’s surprisingly enjoyable (though sometimes frustrating) and highly recommended. Start the process by running:
With no other args. Then pass in the ‘qdata.txt’ file next time you run the game, to keep your learning session going.
I consulted the following resources to implement my AI. If you want to do similar work, I’d recommend these resources. These people are much smarter than me. I’m just applying their concepts.
Recently, I finished an artificial intelligence project that involved implementing the Minimax and Alpha-Beta pruning algorithms in Python.
These algorithms are standard and useful ways to optimize decision making for an AI-agent, and they are fairly straightforward to implement.
I haven’t seen any actual working implementations of these using Python yet, however. So I’m posting my code as an example for future programmers to improve & expand upon.
It’s also useful to see a working implementation of abstract algorithms sometimes, when you’re seeking greater intuition about how they work in practice.
My hope is that this post provides you with some of that intuition, should you need it–and that it does so at an accelerated pace.
Let’s start with Minimax itself.
Assumptions: This code assumes you have already built a game tree relevant to your problem, and now your task is to parse it. If you haven’t yet built a game tree, that will be the first step in this process. I have a previous post about how I did it for my own problem, and you can use that as a starting point. But keep in mind that YMMV.
My implementation looks like this:
##########################
######   MINI-MAX   ######
##########################
class MiniMax:
    # print utility value of root node (assuming it is max)
    # print names of all nodes visited during search
    def __init__(self, game_tree):
        self.game_tree = game_tree  # GameTree
        self.root = game_tree.root  # GameNode
        self.currentNode = None     # GameNode
        self.successors = []        # List of GameNodes
        return

    def minimax(self, node):
        # first, find the max value
        best_val = self.max_value(node)  # should be root node of tree
        # second, find the node which HAS that max value
        # --> max_value/min_value store each node's computed value on the
        #     node itself, so the values are propagated back up the tree
        successors = self.getSuccessors(node)
        print("MiniMax: Utility Value of Root Node: = " + str(best_val))
        # find the node with our best move
        best_move = None
        for elem in successors:
            if elem.value == best_val:
                best_move = elem
                break
        # return the best move that we've found
        return best_move

    def max_value(self, node):
        print("MiniMax-->MAX: Visited Node :: " + node.Name)
        if self.isTerminal(node):
            return self.getUtility(node)
        infinity = float('inf')
        max_value = -infinity
        successors_states = self.getSuccessors(node)
        for state in successors_states:
            max_value = max(max_value, self.min_value(state))
        node.value = max_value  # propagate the value up the tree
        return max_value

    def min_value(self, node):
        print("MiniMax-->MIN: Visited Node :: " + node.Name)
        if self.isTerminal(node):
            return self.getUtility(node)
        infinity = float('inf')
        min_value = infinity
        successor_states = self.getSuccessors(node)
        for state in successor_states:
            min_value = min(min_value, self.max_value(state))
        node.value = min_value  # propagate the value up the tree
        return min_value

    #                     #
    #   UTILITY METHODS   #
    #                     #
    # successor states in a game tree are the child nodes...
    def getSuccessors(self, node):
        assert node is not None
        return node.children

    # return true if the node has NO children (successor states)
    # return false if the node has children (successor states)
    def isTerminal(self, node):
        assert node is not None
        return len(node.children) == 0

    def getUtility(self, node):
        assert node is not None
        return node.value
How-to: To use this code, create a new instance of the MiniMax object, and pass in your GameTree object. This code should work on any GameTree object that has a root node, and whose nodes have fields for: 1) child nodes (children); 2) value; 3) a name (Name), which is used in the printed trace. (That is, unless I made an error, which of course, is very possible.)
After the MiniMax object is instantiated, run the minimax() function on the root node, and you will see a trace of the program’s output, as the algorithm evaluates each node in turn, before choosing the best possible option.
What you’ll notice: Minimax needs to evaluate **every single node** in your tree. For a small tree, that’s okay. But for a huge AI problem with millions of possible states to evaluate (think: Chess, Go, etc.), this isn’t practical.
How we solve: To solve the problem of looking at every single node, we can implement a pruning improvement to Minimax, called Alpha-Beta.
Alpha-Beta Pruning Improvement
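Essentially, Alpha-Beta pruning works by keeping track of the best value each player can already guarantee as the algorithm traverses the tree (alpha for the maximizer, beta for the minimizer).
Then, if we ever reach a node where a child’s value already disqualifies that node as an option (at or above beta in a MAX node, at or below alpha in a MIN node), we just skip its remaining children.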
Rather than going into a theoretical discussion of WHY Alpha-Beta works, this post is focused on the HOW. For me, it’s easier to see the how and work backwards to why. So here’s the quick and dirty implementation.
##########################
###### MINI-MAX A-B ######
##########################
class AlphaBeta:
    # print utility value of root node (assuming it is max)
    # print names of all nodes visited during search
    def __init__(self, game_tree):
        self.game_tree = game_tree  # GameTree
        self.root = game_tree.root  # GameNode
        return

    def alpha_beta_search(self, node):
        infinity = float('inf')
        best_val = -infinity
        beta = infinity
        successors = self.getSuccessors(node)
        best_state = None
        for state in successors:
            # best_val doubles as alpha: the best value MAX can guarantee so far
            value = self.min_value(state, best_val, beta)
            if value > best_val:
                best_val = value
                best_state = state
        print("AlphaBeta: Utility Value of Root Node: = " + str(best_val))
        print("AlphaBeta: Best State is: " + best_state.Name)
        return best_state

    def max_value(self, node, alpha, beta):
        print("AlphaBeta-->MAX: Visited Node :: " + node.Name)
        if self.isTerminal(node):
            return self.getUtility(node)
        infinity = float('inf')
        value = -infinity
        successors = self.getSuccessors(node)
        for state in successors:
            value = max(value, self.min_value(state, alpha, beta))
            # MIN (our parent) already has an option at least this good for it: prune
            if value >= beta:
                return value
            alpha = max(alpha, value)
        return value

    def min_value(self, node, alpha, beta):
        print("AlphaBeta-->MIN: Visited Node :: " + node.Name)
        if self.isTerminal(node):
            return self.getUtility(node)
        infinity = float('inf')
        value = infinity
        successors = self.getSuccessors(node)
        for state in successors:
            value = min(value, self.max_value(state, alpha, beta))
            # MAX (our parent) already has an option at least this good for it: prune
            if value <= alpha:
                return value
            beta = min(beta, value)
        return value

    #                     #
    #   UTILITY METHODS   #
    #                     #
    # successor states in a game tree are the child nodes...
    def getSuccessors(self, node):
        assert node is not None
        return node.children

    # return true if the node has NO children (successor states)
    # return false if the node has children (successor states)
    def isTerminal(self, node):
        assert node is not None
        return len(node.children) == 0

    def getUtility(self, node):
        assert node is not None
        return node.value
How-to: This algorithm works the same as Minimax. Instantiate a new object with your GameTree as an argument, and then call alpha_beta_search() on the root node.
What you’ll notice: Alpha-Beta pruning will always give us the same root value as Minimax (if called on the same input), but it will usually require evaluating fewer nodes, often far fewer. Tracing through the code will illustrate why.
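To see the difference in node counts without building a GameTree first, here is a compact, self-contained sketch that runs both searches directly on a nested list (inner nodes are lists starting with a name, leaves are (name, value) tuples). The function names and the visited lists are hypothetical; the pruning tests mirror the ones in the class above.

```python
import math

# Example tree: the root 'A' is a MAX node, and levels alternate MAX/MIN.
TREE = ['A', ['B', ('D', 3), ('E', 5)],
        ['C', ['F', ['I', ('K', 0), ('L', 7)], ('J', 5)],
         ['G', ('M', 7), ('N', 8)], ('H', 4)]]

def minimax(node, maximizing, visited):
    visited.append(node[0])          # record every node we look at
    if isinstance(node, tuple):      # leaf: (name, value)
        return node[1]
    values = [minimax(child, not maximizing, visited) for child in node[1:]]
    return max(values) if maximizing else min(values)

def alphabeta(node, alpha, beta, maximizing, visited):
    visited.append(node[0])
    if isinstance(node, tuple):
        return node[1]
    value = -math.inf if maximizing else math.inf
    for child in node[1:]:
        child_value = alphabeta(child, alpha, beta, not maximizing, visited)
        if maximizing:
            value = max(value, child_value)
            if value >= beta:        # MIN above will never allow this branch
                return value
            alpha = max(alpha, value)
        else:
            value = min(value, child_value)
            if value <= alpha:       # MAX above already has something better
                return value
            beta = min(beta, value)
    return value

mm_visited, ab_visited = [], []
mm_value = minimax(TREE, True, mm_visited)
ab_value = alphabeta(TREE, -math.inf, math.inf, True, ab_visited)
print(mm_value, len(mm_visited))     # prints 4 14
print(ab_value, len(ab_visited))     # prints 4 12
```

On this small tree the saving is modest (leaves L and N are never examined), but the gap grows quickly with depth and branching factor, especially when good moves happen to be searched first.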
This isn’t the most robust implementation of either algorithm (in fact it’s deficient in many ways), so I wouldn’t recommend it for industrial use.
However, this code should simply illustrate how each algorithm works, and it will provide output you can trace through and compare against–as long as you are able to construct the GameTree for your problem.
From there, it’s only a matter of time until you’ll understand it intuitively. This is one of those things that took a little while for me to grasp–so hopefully having a clear example will help others get there more quickly. Good luck.
I’ve been working on an AI project today and came across this problem.
Given input data structured like so:
['A', ['B', ('D', 3), ('E', 5)], ['C', ['F', ['I',('K',0), ('L', 7)],('J',5)], ['G', ('M',7), ('N',8)], ('H',4)]]
I need to parse this and build a tree which has an arbitrary branching factor, and values only at the leaves.
(As for why: Later, I’ll be running Minimax and some other algorithms on this tree, in order to algorithmically determine the best possible game move. More on that in another post.)
This seemed like a good problem to solve recursively. And to avoid a soul-sucking debug session, I decided my goal was to solve it as succinctly as possible.
Here’s what I came up with. Why I’m posting: This seems like it would be a very common AI/Data-Structures problem, but my first few searches on the subject came up with nada. Nothing even closely related to the problem I’m solving. So I’m doing my part to fix that now.
""" @author Tony Poerio
    @email tony@tonypoer.io
    tree_parser.py --> parse a nested data string into a tree.
    Only leaf nodes have values.
    I intend to run minimax algorithms on these trees for a competitive game AI
    Data should be in the following format:
    ['A', ['B', ('D', 3), ('E', 5)], ['C', ['F', ['I',('K',0), ('L', 7)],('J',5)], ['G', ('M',7), ('N',8)], ('H',4)]]
    Note that Leaves must be **tuples**
    Usage: python tree_parser.py [filename]
    File should have data in the format shown above.
"""
from ast import literal_eval
import sys

##########################
###### PARSE DATA ########
##########################
def parse_data_as_list(fname):
    # read the whole file and safely evaluate it as a Python literal
    with open(fname, "r") as f:
        data_as_string = f.read()
        print(data_as_string)
        data_list = literal_eval(data_as_string)
    return data_list

class GameNode:
    def __init__(self, name, value=0, parent=None):
        self.Name = name      # a char
        self.value = value    # an int
        self.parent = parent  # a node reference
        self.children = []    # a list of nodes

    def addChild(self, childNode):
        self.children.append(childNode)

class GameTree:
    def __init__(self):
        self.root = None

    def build_tree(self, data_list):
        """
        :param data_list: Take data in list format
        :return: Parse a tree from it
        """
        # note: pop() consumes the list, so data_list is modified in place
        self.root = GameNode(data_list.pop(0))
        for elem in data_list:
            self.parse_subtree(elem, self.root)

    def parse_subtree(self, data_list, parent):
        # base case
        if type(data_list) is tuple:
            # make connections
            leaf_node = GameNode(data_list[0])
            leaf_node.parent = parent
            parent.addChild(leaf_node)
            # if we're at a leaf, set the value
            if len(data_list) == 2:
                leaf_node.value = data_list[1]
            return
        # recursive case
        tree_node = GameNode(data_list.pop(0))
        # make connections
        tree_node.parent = parent
        parent.addChild(tree_node)
        for elem in data_list:
            self.parse_subtree(elem, tree_node)
        # return from entire method if base case and recursive case both done running
        return

##########################
#### MAIN ENTRY POINT ####
##########################
def main():
    filename = sys.argv[1]
    print("hello world! " + filename)
    data_list = parse_data_as_list(filename)
    data_tree = GameTree()
    data_tree.build_tree(data_list)

if __name__ == "__main__":
    main()
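Before building the full tree, it can help to sanity-check that your input parses the way you expect. The sketch below is not part of tree_parser.py: it reads the sample data from a string instead of a file, and the helpers leaves and depth are hypothetical names for two quick recursive checks over the nested list.

```python
from ast import literal_eval

RAW = ("['A', ['B', ('D', 3), ('E', 5)], ['C', ['F', ['I',('K',0), ('L', 7)],('J',5)], "
       "['G', ('M',7), ('N',8)], ('H',4)]]")

def leaves(node):
    """Collect every (name, value) leaf tuple, left to right."""
    if isinstance(node, tuple):
        return [node]
    found = []
    for child in node[1:]:           # node[0] is the node's own name
        found.extend(leaves(child))
    return found

def depth(node):
    """Number of edges on the longest path from this node down to a leaf."""
    if isinstance(node, tuple):
        return 0
    return 1 + max(depth(child) for child in node[1:])

data = literal_eval(RAW)             # safe parsing: literals only, no code execution
print(leaves(data))
print(depth(data))                   # prints 4
```

If the leaf list or the depth doesn’t match what you drew on paper, the problem is in the input data rather than in the tree-building code.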
Side note. I’m actually not sure what this tree (with weights only at the leaves) would be called technically. It reminds me of the tree made during Huffman Encoding, but it’s not quite a match for that since we aren’t summing the values in all parent nodes. If you know the technical name, let me know, so I can update.