Skip to content

Commit

Permalink
Initial source commit
Browse files Browse the repository at this point in the history
  • Loading branch information
thunfischtoast committed Jul 22, 2018
1 parent 9a33dec commit e3018d5
Show file tree
Hide file tree
Showing 4 changed files with 515 additions and 0 deletions.
24 changes: 24 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
<?xml version="1.0" encoding="UTF-8"?>
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>thunfischtoast</groupId>
    <artifactId>linucb</artifactId>
    <version>1.0</version>

    <properties>
        <!-- The sources contain non-ASCII characters (author name in the license headers);
             pin the encoding so the build does not depend on the platform default.
             NOTE(review): assumes the files are stored as UTF-8 - confirm. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <build>
        <!-- The Java sources live directly below src/ (src/de/thunfischtoast/...),
             not in Maven's default src/main/java, so the source root must be set explicitly. -->
        <sourceDirectory>src</sourceDirectory>
    </build>

    <dependencies>
        <!-- Linear algebra (RealMatrix, RealVector, MatrixUtils) used by the bandit implementations. -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-math3</artifactId>
            <version>3.6.1</version>
        </dependency>
    </dependencies>

</project>
118 changes: 118 additions & 0 deletions src/de/thunfischtoast/BanditTest.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,118 @@
/* Copyright (C) 2018 Christian Römer
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Contact: https://github.com/thunfischtoast or christian.roemer[at]udo.edu
*/

package de.thunfischtoast;

import org.apache.commons.math3.linear.ArrayRealVector;
import org.apache.commons.math3.linear.RealVector;

import java.util.Random;

/**
 * Small demo class inspired by John Maxwell (http://john-maxwell.com/post/2017-03-17/). We create a context of two
 * features that represent reading preferences of imaginary news site visitors. The visitors like or dislike sites with
 * sports content (context[0] = 1 or 0) and like or dislike sites with politics content (context[1] = 1 or 0). The
 * features are restricted to the integers 0 and 1 for easier analysis. The bandit should offer one of three sites to
 * the visitor, each having different contents to offer for sports and politics.
 *
 * The algorithm implementation has problems when all context features are 0, as the expected reward will always
 * become 0 as well. It is currently not apparent if this is a problem of the implementation of the algorithm itself.
 */
public class BanditTest {

    /** Number of simulated visitors. */
    private static final int ROUNDS = 10000;

    /**
     * Largest value {@link #getMean(int, double[])} can return for a 0/1 context
     * (arm 0 and arm 2 with both features set: baseline + sportsCoef + politicsCoef = 0.325).
     */
    private static final double MAX_MEAN = 0.325;

    /**
     * Runs the simulation: prints the true mean of every (arm, context) pair, plays {@link #ROUNDS} rounds against
     * the bandit and finally prints the observed reward range and how often each arm was chosen per context.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        HybridLinUCB linUCB = new HybridLinUCB(2, 2, 3, 18);
//        LinUCB linUCB = new LinUCB(2, 3, 5);

        // fixed seed so that runs are reproducible
        Random random = new Random(7);

        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) {
                for (int i = 0; i < 3; i++) {
                    System.out.println("Arm " + i + ", Context (" + j + ", " + k + ") mean is " + getMean(i, new double[]{j, k}));
                }
            }
        }

        double maxReward = Double.NEGATIVE_INFINITY;
        double minReward = Double.POSITIVE_INFINITY;
        // counters[arm][sports][politics] = how often the arm was chosen for that context
        int[][][] counters = new int[3][2][2];

        for (int i = 0; i < ROUNDS; i++) {
            double sports = random.nextInt(2);
            double politics = random.nextInt(2);

            ArrayRealVector context = new ArrayRealVector(new double[]{sports, politics});
            // the hybrid model expects [sharedContext, nonSharedContext]; the same two features are used for both
            if (linUCB instanceof HybridLinUCB) {
                context = context.append(context);
            }

            int arm = linUCB.chooseArm(context);

            // getMean only reads the first two entries, so passing the (possibly doubled) context is fine
            double reward = normalizeReward(nextBoundedGaussian(random) + getMean(arm, context.toArray()));
            maxReward = Math.max(maxReward, reward);
            minReward = Math.min(minReward, reward);
            linUCB.receiveRewards(new RealVector[]{context}, new int[]{arm}, new double[]{reward});

            counters[arm][(int) sports][(int) politics]++;
        }

        System.out.println("Max reward is " + maxReward + " min is " + minReward);

        for (int j = 0; j < 2; j++) {
            for (int k = 0; k < 2; k++) {
                System.out.print("Chosen arm counts for context (" + j + ", " + k + "): ");
                System.out.println(counters[0][j][k] + ", " + counters[1][j][k] + ", " + counters[2][j][k] + ", ");
            }
        }
    }

    /**
     * Maps a raw reward (noise in [-1, 1] plus a mean in [0, {@link #MAX_MEAN}]) into [0, 1].
     * The raw value lies in [-1, 1 + MAX_MEAN], so it is shifted by 1 and divided by the width of that range;
     * the previous divisor of 2.25 allowed results of up to about 1.033. The final clamp only guards against
     * floating point rounding at the borders.
     *
     * @param raw noise plus arm mean
     * @return the reward scaled into [0, 1]
     */
    private static double normalizeReward(double raw) {
        double scaled = (raw + 1) / (2 + MAX_MEAN);
        return Math.max(0, Math.min(1, scaled));
    }

    /**
     * Returns the true expected (un-normalized) reward of an arm for the given context.
     *
     * @param arm     index of the arm; 0 and 1 have their own coefficients, every other value uses the third set
     * @param context context[0] is the sports preference, context[1] the politics preference; further entries are ignored
     * @return armBaseline + context[0] * sportsCoef + context[1] * politicsCoef
     */
    private static double getMean(int arm, double[] context) {
        double sportsCoef;
        double politicsCoef;
        double armBaseline;

        if (arm == 0) {
            sportsCoef = 0.25;
            politicsCoef = 0.05;
            armBaseline = 0.025;
        } else if (arm == 1) {
            sportsCoef = 0.05;
            politicsCoef = 0.025;
            armBaseline = 0.05;
        } else {
            sportsCoef = 0.05;
            politicsCoef = 0.2;
            armBaseline = 0.075;
        }

        return armBaseline + context[0] * sportsCoef + context[1] * politicsCoef;
    }

    /**
     * Returns a pseudorandom double in [-1, 1]: a standard Gaussian sample (mean 0, standard deviation 1) is
     * clipped to [-4, 4] and divided by 4, so the result has mean 0 and a standard deviation of roughly 0.25.
     *
     * @param random source of randomness
     * @return the clipped and rescaled Gaussian sample
     */
    private static double nextBoundedGaussian(Random random) {
        double v = random.nextGaussian();
        v = Math.min(4, v);
        v = Math.max(-4, v);
        return v / 4;
    }

}
172 changes: 172 additions & 0 deletions src/de/thunfischtoast/HybridLinUCB.java
Original file line number Diff line number Diff line change
@@ -0,0 +1,172 @@
/* Copyright (C) 2018 Christian Römer
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU General Public License for more details.
You should have received a copy of the GNU General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Contact: https://github.com/thunfischtoast or christian.roemer[at]udo.edu
*/

package de.thunfischtoast;

import org.apache.commons.math3.linear.*;

/**
 * This class implements a contextual bandit algorithm called LinUCB as proposed by Li, Langford and Schapire.
 * This is the version with hybrid linear models: the expected payoff of an arm is modelled as
 * z^T * beta + x^T * theta_a, where z are the shared features (coefficients beta shared by all arms) and
 * x are the non-shared features (coefficients theta_a specific to arm a).
 *
 * @inproceedings{li2010contextual,
 * title={A contextual-bandit approach to personalized news article recommendation},
 * author={Li, Lihong and Chu, Wei and Langford, John and Schapire, Robert E},
 * booktitle={Proceedings of the 19th international conference on World wide web},
 * pages={661--670},
 * year={2010},
 * organization={ACM}
 * }
 *
 * @author Christian Römer
 */
public class HybridLinUCB extends LinUCB {

    /** Linear regression parameters for shared model (k x 1) */
    private RealMatrix beta_hat;

    /** Per-arm accumulators coupling non-shared and shared features (each d x k) */
    private final RealMatrix[] B_a;

    /** Context accumulator for shared model (k x k) */
    private RealMatrix A_0;

    /** Reward accumulator for shared model (k x 1) */
    private RealMatrix b_0;

    /** Number of shared features */
    private final int k;

    /**
     * @param d number of non-shared features
     * @param k number of shared features
     * @param n number of arms
     * @param alpha how many times the standard deviation of the expected payoff is added to the predicted payoff in the ridge regression
     * @throws IllegalArgumentException if k is not positive
     */
    public HybridLinUCB(int d, int k, int n, double alpha) {
        super(d, n, alpha);

        if (k <= 0) {
            throw new IllegalArgumentException("Number of features > 0. If there is are no shared features use @LinUCB");
        }

        this.k = k;

        A_0 = MatrixUtils.createRealIdentityMatrix(k);
        b_0 = new Array2DRowRealMatrix(k, 1);

        beta_hat = MatrixUtils.inverse(A_0).multiply(b_0);

        B_a = new RealMatrix[n];
        for (int i = 0; i < n; i++) {
            B_a[i] = new Array2DRowRealMatrix(d, k);
        }
    }

    /**
     * Compute the payoff estimate of every arm for a context given in two parts. The two vectors are
     * concatenated as [sharedContext, combinedContext] and passed to {@link #getPayoffs(RealVector)}, so the
     * second argument has to hold the non-shared features only.
     *
     * @param sharedContext the k shared features
     * @param combinedContext the d non-shared features (appended behind the shared ones)
     * @return one payoff estimate per arm
     */
    public double[] getPayoffs(RealVector sharedContext, RealVector combinedContext){
        return getPayoffs(sharedContext.append(combinedContext));
    }

    /**
     * Compute the payoff estimate of every arm for the given context: the predicted payoff
     * z^T * beta_hat + x^T * theta_hat_a plus alpha times the square root of the variance term s.
     * The given context must be of form [sharedContext,nonSharedContext].
     *
     * @param combinedContext vector of dimension k + d
     * @return one payoff estimate per arm
     * @throws IllegalArgumentException if the context does not have dimension k + d
     */
    @Override
    public double[] getPayoffs(RealVector combinedContext) {
        if (combinedContext.getDimension() != k + d) {
            throw new IllegalArgumentException("The given context must be of form [sharedContext,nonSharedContext]!");
        }

        RealVector sharedContext = combinedContext.getSubVector(0, k);
        RealVector nonSharedContext = combinedContext.getSubVector(k, d);

        RealMatrix x = new Array2DRowRealMatrix(nonSharedContext.toArray()); // d x 1
        RealMatrix x_t = x.transpose();

        RealMatrix z = new Array2DRowRealMatrix(sharedContext.toArray()); // k x 1
        RealMatrix z_t = z.transpose();

        RealMatrix A_0_inv = MatrixUtils.inverse(A_0);

        // these terms do not depend on the arm, so compute them once
        RealMatrix z_t_A_0_inv = z_t.multiply(A_0_inv); // 1 x k
        double first = z_t_A_0_inv.multiply(z).getEntry(0, 0);
        double firstElement = z_t.multiply(beta_hat).getEntry(0, 0);

        double[] payoffs = new double[n];

        for (int i = 0; i < n; i++) {
            RealMatrix B_a_t = B_a[i].transpose(); // k x d
            RealMatrix A_a_inv_x = A_a_inverse[i].multiply(x); // d x 1

            // 2 * z^T A_0^-1 B_a^T A_a^-1 x; the leading inverse has to be the k x k matrix A_0^-1
            // (multiplying z^T with the d x d matrix A_a^-1 is wrong and fails whenever k != d)
            double second = z_t_A_0_inv.multiply(B_a_t).multiply(A_a_inv_x).getEntry(0, 0) * 2;
            double third = x_t.multiply(A_a_inv_x).getEntry(0, 0);
            double fourth = x_t.multiply(A_a_inverse[i]).multiply(B_a[i]).multiply(A_0_inv)
                    .multiply(B_a_t).multiply(A_a_inv_x).getEntry(0, 0);

            double s = first - second + third + fourth;

            double secondElement = x_t.multiply(theta_hat_a[i]).getEntry(0, 0);

            // NOTE(review): the confidence bonus is skipped whenever the shared prediction is exactly 0
            // (e.g. while beta_hat is still all zero or for an all-zero shared context). The published
            // algorithm has no such condition - confirm whether this is intended.
            if (firstElement != 0) {
                // abs() guards against a slightly negative s caused by rounding
                payoffs[i] = firstElement + secondElement + (alpha * Math.sqrt(Math.abs(s)));
            } else {
                payoffs[i] = firstElement + secondElement;
            }
        }

        return payoffs;
    }

    /**
     * Receive multiple rewards for the given contexts and arms. Update the regression parameters accordingly.
     * The given contexts must be of form [sharedContext,nonSharedContext].
     * All arguments are validated before any state is changed.
     *
     * @param combinedContexts one context of dimension k + d per observation
     * @param arm index of the arm that was played for each observation
     * @param reward reward that was observed for each observation
     * @throws IllegalArgumentException if the arrays differ in length, a context has the wrong dimension
     *                                  or an arm index is out of range
     */
    @Override
    public void receiveRewards(RealVector[] combinedContexts, int[] arm, double[] reward) {
        if (combinedContexts.length != arm.length || combinedContexts.length != reward.length) {
            throw new IllegalArgumentException("Contexts, arms and rewards must have the same length!");
        }
        for (int i = 0; i < combinedContexts.length; i++) {
            if (combinedContexts[i].getDimension() != k + d) {
                throw new IllegalArgumentException("The given context must be of form [sharedContext,nonSharedContext]!");
            }
            if (arm[i] < 0 || arm[i] >= n) {
                throw new IllegalArgumentException("Arm index " + arm[i] + " is out of range!");
            }
        }

        for (int i = 0; i < combinedContexts.length; i++) {
            updateAccumulators(combinedContexts[i], arm[i], reward[i]);
        }

        // the estimates are not read by the accumulator updates, so refreshing them once at the end is sufficient
        refreshEstimates();
    }

    /** Fold a single (context, arm, reward) observation into the shared and per-arm accumulators. */
    private void updateAccumulators(RealVector combinedContext, int chosenArm, double r) {
        RealMatrix z = new Array2DRowRealMatrix(combinedContext.getSubVector(0, k).toArray()); // k x 1
        RealMatrix x = new Array2DRowRealMatrix(combinedContext.getSubVector(k, d).toArray()); // d x 1

        // add the arm's contribution based on its state before this observation
        RealMatrix B_t_A_inv = B_a[chosenArm].transpose().multiply(A_a_inverse[chosenArm]); // k x d
        A_0 = A_0.add(B_t_A_inv.multiply(B_a[chosenArm]));
        b_0 = b_0.add(B_t_A_inv.multiply(b_a[chosenArm].transpose()));

        A_a[chosenArm] = A_a[chosenArm].add(x.multiply(x.transpose())); // A_a += x * x^T
        B_a[chosenArm] = B_a[chosenArm].add(x.multiply(z.transpose())); // B_a += x * z^T
        b_a[chosenArm] = b_a[chosenArm].add(x.scalarMultiply(r).transpose()); // b_a += r * x (b_a is kept as a row)

        // the inverse has to reflect the updated A_a before it is used below;
        // only the played arm changed, so only its inverse needs to be recomputed
        A_a_inverse[chosenArm] = MatrixUtils.inverse(A_a[chosenArm]);

        // subtract the arm's contribution again, now based on its updated state
        B_t_A_inv = B_a[chosenArm].transpose().multiply(A_a_inverse[chosenArm]);
        A_0 = A_0.add(z.multiply(z.transpose())).subtract(B_t_A_inv.multiply(B_a[chosenArm]));
        b_0 = b_0.add(z.scalarMultiply(r)).subtract(B_t_A_inv.multiply(b_a[chosenArm].transpose()));
    }

    /** Recompute beta_hat and, because every theta_hat_a depends on beta_hat, the coefficients of all arms. */
    private void refreshEstimates() {
        beta_hat = MatrixUtils.inverse(A_0).multiply(b_0);

        for (int j = 0; j < A_a.length; j++) {
            // hybrid model: theta_hat_a = A_a^-1 * (b_a - B_a * beta_hat)
            RealMatrix residual = b_a[j].transpose().subtract(B_a[j].multiply(beta_hat));
            theta_hat_a[j] = A_a_inverse[j].multiply(residual);
        }
    }
}
Loading

0 comments on commit e3018d5

Please sign in to comment.