diff --git a/.gitignore b/.gitignore index 5a28fda..4c69b46 100755 --- a/.gitignore +++ b/.gitignore @@ -3,4 +3,6 @@ build *.jar /library -/gradle \ No newline at end of file +/gradle +ga/ +neural_network/ \ No newline at end of file diff --git a/src/main/java/board/Board.java b/src/main/java/board/Board.java index c0b4b6a..5c53f28 100755 --- a/src/main/java/board/Board.java +++ b/src/main/java/board/Board.java @@ -5,9 +5,7 @@ import observer.OnUpdateObserver; import processing.core.PApplet; import util.Config; - import java.util.ArrayList; -import java.util.Arrays; import java.util.List; public class Board { @@ -37,11 +35,9 @@ public void setOnRefreshListener(Observer observer) { /* SNAKES CONTROL */ void makeSnakeStep() { - for (Snake snake : snakeList) { + for (int i = 0; i < snakeList.size(); i++) { + Snake snake = snakeList.get(i); snake.makeStep(); - int[] vision = snake.getVision(); - System.out.println(Arrays.toString(vision)); - } } @@ -64,7 +60,8 @@ public void draw() { } void drawSnakes() { - for (Snake snake : snakeList) { + for (int i = 0; i < snakeList.size(); i++) { + Snake snake = snakeList.get(i); snake.drawSnake(); } } diff --git a/src/main/java/board/snake/Snake.java b/src/main/java/board/snake/Snake.java index e5dfbe8..c58bd69 100755 --- a/src/main/java/board/snake/Snake.java +++ b/src/main/java/board/snake/Snake.java @@ -21,10 +21,11 @@ public class Snake { private boolean canChangeDirection = true; private boolean snakeIncrease = true; - private boolean snakeFinished = false; + protected boolean snakeFinished = false; private final Food food; - private int score; + protected int score; + protected int steps; private final List onFinishObservers; private final List observerList; @@ -35,7 +36,9 @@ public Snake(int[][] boardMatrix) { this.snakeCellsHashSet = new HashSet<>(); this.onFinishObservers = new ArrayList<>(); this.observerList = new ArrayList<>(); + this.score = 0; + this.steps = 0; this.food = new Food(boardMatrix); initialize(); @@ -49,7 +52,7 @@ public void initialize() { direction[0] = Util.generateRandom(-1, 2); direction[1] = 0; if (direction[0] == 0) { - direction[1] = Util.generateRandom(-1, 2); + direction[1] = 1; } initializeSnakePosition(); @@ -62,8 +65,11 @@ public void initialize() { } private void initializeSnakePosition() { - int currentI = Config.BOARD_ROWS / 2 + Util.generateRandom(-7, 7); - int currentJ = Config.BOARD_COLUMNS / 2 + Util.generateRandom(-7, 7); +// int currentI = Config.BOARD_ROWS / 2 + Util.generateRandom(-7, 7); +// int currentJ = Config.BOARD_COLUMNS / 2 + Util.generateRandom(-7, 7); + + int currentI = Config.BOARD_ROWS / 2; + int currentJ = Config.BOARD_COLUMNS / 2; if (Config.BOARD_COLUMNS < 10) { currentI = 2; @@ -129,6 +135,8 @@ public void makeStep() { } canChangeDirection = true; + + steps++; } /* DRAWING LOGIC */ @@ -187,7 +195,7 @@ public void repositionFood() { public void addPoint() { snakeIncrease = true; - score+= 5; + score++; for (Observer observer : observerList) { if (observer instanceof OnFoodEaten) { @@ -235,7 +243,6 @@ public int[] getDirection() { /* Gets a 24 double vector */ public int[] getVision() { - // degrees are in reverse because the coords are in reverse int[] head = direction; int[] west = GeneticAiUtil.reverse(GeneticAiUtil.rotateVectorInt(head[1], head[0], -90)); @@ -264,10 +271,10 @@ public int[] getVision() { int[] vision = new int[8 * 3]; - System.arraycopy(headData, 0, vision, 0, 3); - System.arraycopy(northWestData, 0, vision, 3, 3); - System.arraycopy(westData, 0, vision, 6, 3); - System.arraycopy(southWestData, 0, vision, 9, 3); + System.arraycopy(headData, 0, vision, 0, 3); + System.arraycopy(northWestData, 0, vision, 3, 3); + System.arraycopy(westData, 0, vision, 6, 3); + System.arraycopy(southWestData, 0, vision, 9, 3); System.arraycopy(southData, 0, vision, 12, 3); System.arraycopy(southEastData, 0, vision, 15, 3); System.arraycopy(eastData, 0, vision, 18, 3); @@ -336,7 +343,7 @@ private double[] normalizeDistances(int[] distances) { double[] normalizedDistance = new double[n]; - double max = Math.max(boardMatrix.length, boardMatrix[0].length); + double max = Math.max(boardMatrix.length, boardMatrix[0].length) + 1; for (int i = 0; i < n; i++) { normalizedDistance[i] = 1.0 * distances[i] / max; @@ -385,4 +392,8 @@ public void moveDown() { direction[1] = 0; canChangeDirection = false; } + + public int getSteps() { + return steps; + } } diff --git a/src/main/java/controller/controlers/genetic/GeneticAiController.java b/src/main/java/controller/controlers/genetic/GeneticAiController.java index ad9423b..928f6db 100644 --- a/src/main/java/controller/controlers/genetic/GeneticAiController.java +++ b/src/main/java/controller/controlers/genetic/GeneticAiController.java @@ -5,6 +5,7 @@ import ga.AbstractGeneticAlgorithm; import ga.config.GaConfig; import ga.lambda.BeforeEvaluationEvent; +import ga.lambda.observers.OnNewGeneration; import ga.member.AbstractMember; import ga.operators.crossover.AbstractCrossover; import ga.operators.crossover.OnePointCrossover; @@ -12,11 +13,10 @@ import ga.operators.mutation.SimpleMutation; import ga.operators.selection.AbstractSelection; import ga.operators.selection.TournamentSelection; -import neural_network.NeuralNetwork; import processing.core.PApplet; +import util.Config; -import java.util.ArrayList; -import java.util.List; +import java.util.Collections; public class GeneticAiController extends SnakeController { @@ -27,17 +27,14 @@ public GeneticAiController(PApplet pApplet) { @Override public void run() { - int[] layerSize = new int[]{3, 4, 5, 3}; - int geneLength = SmartSnake.calculateGeneLength(layerSize); - - NeuralNetwork neuralNetwork = new NeuralNetwork(layerSize); + int geneLength = SmartSnake.calculateGeneLength(Config.LAYER_SIZE); GaConfig gaConfig = GaConfig.initializeWithParameters( - 100, + 800, 1000, 1, geneLength, - 0.0015, + 0.0005, 0.02 ); @@ -52,21 +49,18 @@ public void run() { abstractSelection ); - List smartSnakes = new ArrayList<>(); - - BeforeEvaluationEvent beforeEvaluationEvent = (population) -> { - for (AbstractMember abstractMember : population) { + BeforeEvaluationEvent beforeEvaluationEvent = (copyReferencePopulation) -> { + /* Put the population on the board */ + snakeList.clear(); + for (AbstractMember abstractMember : copyReferencePopulation) { short[] gene = abstractMember.getGeneCopy(); - Snake snake = new Snake(super.board.getBoardMatrix()); - SmartSnake smartSnake = new SmartSnake(snake, gene, layerSize); - - smartSnakes.add(smartSnake); + SmartSnake smartSnake = new SmartSnake(super.board.getBoardMatrix(), gene, Config.LAYER_SIZE); - snakeList.clear(); - snakeList.add(smartSnake.getSnake()); + snakeList.add(smartSnake); } + /* Wait until the all of them are done */ while (!board.allSnakesFinished()) { try { Thread.sleep(100); @@ -75,34 +69,31 @@ public void run() { } } + /* TODO: implement a method in GA to print all scores */ + + /* Calculate the score (= fitness) */ for (int i = 0; i < snakeList.size(); i++) { Snake snake = snakeList.get(i); - AbstractMember abstractMember = population.get(i); + AbstractMember abstractMember = copyReferencePopulation.get(i); - ((Member) abstractMember).setScore(snake.getScore()); + ((Member) abstractMember).setScore(snake.getScore() * 50 + snake.getSteps() / 10 + 1); + abstractMember.calculateFitness(); } }; - new Thread(abstractGeneticAlgorithm::start).start(); + OnNewGeneration onNewGeneration = ((copyPopulation, generation) -> { + copyPopulation.sort(Collections.reverseOrder()); - try { - Thread.sleep(1000); - } catch (InterruptedException e) { - e.printStackTrace(); - } + System.out.printf("Generation %d. Best fitness %6.2f\n", generation, copyPopulation.get(0).getFitness()); + }); - super.board.start(); + abstractGeneticAlgorithm.addPopulationObserver(onNewGeneration); - while (!board.allSnakesFinished()) { - for (SmartSnake smartSnake : smartSnakes) { - Snake snake = smartSnake.getSnake(); + abstractGeneticAlgorithm.addBeforeEvaluationObserver(beforeEvaluationEvent); - if (!snake.isFinished()) { - smartSnake.predictNext(); - } - } - } + super.board.start(); + new Thread(abstractGeneticAlgorithm::start).start(); } } diff --git a/src/main/java/controller/controlers/genetic/GeneticAiUtil.java b/src/main/java/controller/controlers/genetic/GeneticAiUtil.java index c507fa5..390f19f 100644 --- a/src/main/java/controller/controlers/genetic/GeneticAiUtil.java +++ b/src/main/java/controller/controlers/genetic/GeneticAiUtil.java @@ -115,4 +115,14 @@ public static int[] crushedDiagonalsToDirection(double[] dir) { return newDir; } + public static double getMax(double[] vector) { + double max = 0; + + for (double x : vector) { + max = Math.max(x, max); + } + + return max; + } + } diff --git a/src/main/java/controller/controlers/genetic/GeneticAlgorithm.java b/src/main/java/controller/controlers/genetic/GeneticAlgorithm.java index 1f9b884..9cde8b9 100644 --- a/src/main/java/controller/controlers/genetic/GeneticAlgorithm.java +++ b/src/main/java/controller/controlers/genetic/GeneticAlgorithm.java @@ -2,11 +2,17 @@ import ga.AbstractGeneticAlgorithm; import ga.config.GaConfig; +import ga.conversion.RangeDoubleToInterval; import ga.member.AbstractMember; import ga.operators.crossover.AbstractCrossover; import ga.operators.mutation.AbstractMutation; import ga.operators.selection.AbstractSelection; +import neural_network.NeuralNetwork; +import util.Config; +import java.util.ArrayList; +import java.util.Collection; +import java.util.Collections; import java.util.List; public class GeneticAlgorithm extends AbstractGeneticAlgorithm { @@ -16,11 +22,47 @@ public GeneticAlgorithm(GaConfig gaConfig, AbstractMutation abstractMutation, Ab @Override public List generatePopulation() { - return null; + List population = new ArrayList<>(); + + for (int i = 0; i < gaConfig.populationSize; i++) { + NeuralNetwork neuralNetwork = new NeuralNetwork(Config.LAYER_SIZE); + + double[][][] brain = neuralNetwork.getBrainReference(); + double[] vector = GeneticAiUtil.brainToVector(brain, Config.LAYER_SIZE); + short[] bitmap = RangeDoubleToInterval.toBitMapVector(vector, Config.START, Config.END, Config.PRECISION); + + population.add(new Member(bitmap)); + } + + return population; } @Override public List getAbstractMembersFromGene(List list, GaConfig gaConfig, AbstractMutation abstractMutation) { - return null; + List members = new ArrayList<>(); + + for (short[] gene : list) { + members.add(new Member(gene)); + } + + return members; + } + + @Override + public void selectPopulation(List population) { + List bestNThMembers = new ArrayList<>(); + List copyPopulation = getPopulationCopy(population); + + super.selectPopulation(population); + + copyPopulation.sort(Collections.reverseOrder()); + + int min = Math.min(5, population.size()); + + for (int i = 0; i < min; i++) { + bestNThMembers.add(copyPopulation.get(i)); + } + + population.addAll(bestNThMembers); } } diff --git a/src/main/java/controller/controlers/genetic/Member.java b/src/main/java/controller/controlers/genetic/Member.java index e79342b..70e9b9c 100644 --- a/src/main/java/controller/controlers/genetic/Member.java +++ b/src/main/java/controller/controlers/genetic/Member.java @@ -23,12 +23,17 @@ public void calculateScore() { } - public void setScore(int score) { + public void setScore(double score) { super.score = score; } @Override public AbstractMember getCopy() { - return new Member(Arrays.copyOf(gene, gene.length)); + Member member = new Member(Arrays.copyOf(gene, gene.length)); + + member.setScore(score); + member.calculateFitness(); + + return member; } } diff --git a/src/main/java/controller/controlers/genetic/SmartSnake.java b/src/main/java/controller/controlers/genetic/SmartSnake.java index 84af3e4..be3c3e5 100644 --- a/src/main/java/controller/controlers/genetic/SmartSnake.java +++ b/src/main/java/controller/controlers/genetic/SmartSnake.java @@ -3,46 +3,72 @@ import board.snake.Snake; import ga.conversion.RangeDoubleToInterval; import neural_network.NeuralNetwork; - -public class SmartSnake { - private final Snake snake; - +import neural_network.activation.ActivationFunction; +import neural_network.activation.SigmoidFunction; +import neural_network.activation.TanhFunction; +import neural_network.bias.BiasInit; +import neural_network.bias.NullBias; +import neural_network.cost.CostFunction; +import neural_network.cost.SimpleCost; +import neural_network.weights.WeightsInit; +import neural_network.weights.XavierWeightsInit; +import util.Config; + +import java.util.Arrays; + +public class SmartSnake extends Snake { private final NeuralNetwork neuralNetwork; - private final short[] gene; - /* I know, I know, uppercase */ - private static final double start = -10; - private static final double end = 10; - private static final int precision = 3; + private static ActivationFunction activationFunction = new SigmoidFunction(); + private static WeightsInit weightsInit = new XavierWeightsInit(); + private static BiasInit biasInit = new NullBias(); + private static CostFunction costFunction = new SimpleCost(); +// private static LossFunction lossFunction = new MeanSquaredError(); + + private int consecutiveSteps = 0; + private int oldScore = 0; - public SmartSnake(Snake snake, short[] gene, int[] layerSize) { - this.snake = snake; - this.gene = gene; + public SmartSnake(int[][] board, short[] gene, int[] layerSize) { + super(board); - double[] weights = RangeDoubleToInterval.toDoubleVector(gene, start, end, precision); + double[] weights = RangeDoubleToInterval.toDoubleVector(gene, Config.START, Config.END, Config.PRECISION); double[][][] brain = GeneticAiUtil.vectorToBrain(weights, layerSize); - this.neuralNetwork = new NeuralNetwork(brain, layerSize); + this.neuralNetwork = new NeuralNetwork(brain, layerSize, activationFunction, weightsInit, biasInit, costFunction); + + oldScore = score; } - public SmartSnake(Snake snake, int[] layerSize) { - this.snake = snake; - this.neuralNetwork = new NeuralNetwork(layerSize); - double[][][] brain = neuralNetwork.getBrainReference(); + /* 8 directions: for each one sees the distance between the walls, tail and food, in total 24 */ + public void predictNext() { + double[] input = getNormalizedVision(); + double[] output = neuralNetwork.feedForward(input); - double[] weights = GeneticAiUtil.brainToVector(brain, layerSize); + double maxProb = GeneticAiUtil.getMax(output); - this.gene = RangeDoubleToInterval.toBitMapVector(weights, start, end, precision); - } +// System.out.println("Input: " + Arrays.toString(input)); +// System.out.println("Probs: " + Arrays.toString(output)); - public Snake getSnake() { - return snake; - } + /* UP */ + if (maxProb == output[0]) { + moveUp(); + } + + /* LEFT */ + if (maxProb == output[1]) { + moveLeft(); + } + + /* DOWN */ + if (maxProb == output[2]) { + moveDown(); + } + + /* RIGHT */ + if (maxProb == output[3]) { + moveRight(); + } - /* 8 directions: for each one sees the distance between the walls, tail and food, in total 24 */ - public void predictNext() { - double[] input; - double[] output; } public static int calculateGeneLength(int[] layerSize) { @@ -52,9 +78,26 @@ public static int calculateGeneLength(int[] layerSize) { elements += layerSize[i] * layerSize[i + 1]; } - int lengthPerElement = RangeDoubleToInterval.calculateBitPointLength(end - start, precision); + int lengthPerElement = RangeDoubleToInterval.calculateBitPointLength(Config.END - Config.START, Config.PRECISION); return elements * lengthPerElement; } + @Override + public void makeStep() { + predictNext(); + super.makeStep(); + + if (oldScore != score) { + consecutiveSteps = 0; + oldScore = score; + } + + if (consecutiveSteps == 28 * 28) { + snakeFinished = true; + } + + consecutiveSteps++; + } + } diff --git a/src/main/java/util/Config.java b/src/main/java/util/Config.java index 74019f6..0a54346 100755 --- a/src/main/java/util/Config.java +++ b/src/main/java/util/Config.java @@ -1,11 +1,11 @@ package util; public abstract class Config { - public static final int GAME_TYPE = GameType.SINGLE_PLAYER; + public static final int GAME_TYPE = GameType.GENETIC_AI; /* CANVAS */ public static final int FRAMERATE = 60; - public static final int REFRESH_RATE = 10; + public static final int REFRESH_RATE = 1; /* BOARD */ public static final int BOARD_WIDTH = 650; @@ -22,4 +22,13 @@ public abstract class Config { public static final int BOARD_X = 2; public static final int BOARD_Y = 2; + + /* NN & G */ + public static final int[] LAYER_SIZE = new int[]{24, 16, 4}; + + public static final int START = -5; + public static final int END = 5; + + public static final int PRECISION = 3; + } diff --git a/src/main/java/window/MainWindow.java b/src/main/java/window/MainWindow.java index 2ec03e7..f608595 100755 --- a/src/main/java/window/MainWindow.java +++ b/src/main/java/window/MainWindow.java @@ -4,18 +4,17 @@ import controller.controlers.SinglePlayerController; import controller.SnakeController; import controller.controlers.TwoPlayersController; +import controller.controlers.genetic.GeneticAiController; import processing.core.PApplet; import util.Config; import util.GameType; public class MainWindow extends PApplet { - public void settings() { size(Config.CANVAS_WIDTH, Config.CANVAS_HEIGHT); } - public void setup() { // surface.setLocation(displayWidth / 2, displayHeight / 2 - height / 2); surface.setTitle("Conquering The Snake"); @@ -31,12 +30,12 @@ public void setup() { snakeController = switch (gameType) { case GameType.SINGLE_PLAYER -> new SinglePlayerController(this); case GameType.TWO_PLAYERS -> new TwoPlayersController(this); - default -> new HamiltonController(this); + case GameType.HAMILTON_MODE -> new HamiltonController(this); + default -> new GeneticAiController(this); }; snakeController.run(); } - public void draw(){ } }