-
Notifications
You must be signed in to change notification settings - Fork 11
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
9a33dec
commit e3018d5
Showing
4 changed files
with
515 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
<?xml version="1.0" encoding="UTF-8"?>
<!-- Build descriptor for the linucb artifact: Java 8 sources, single dependency on Apache Commons Math 3. -->
<project xmlns="http://maven.apache.org/POM/4.0.0"
         xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance"
         xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/xsd/maven-4.0.0.xsd">
    <modelVersion>4.0.0</modelVersion>

    <groupId>thunfischtoast</groupId>
    <artifactId>linucb</artifactId>
    <version>1.0</version>

    <properties>
        <!-- Sources contain non-ASCII characters; pin the encoding so the build does not depend on the platform default. -->
        <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding>
        <maven.compiler.source>1.8</maven.compiler.source>
        <maven.compiler.target>1.8</maven.compiler.target>
    </properties>

    <dependencies>
        <!-- Linear algebra (RealMatrix, RealVector, MatrixUtils) used by the bandit implementations. -->
        <dependency>
            <groupId>org.apache.commons</groupId>
            <artifactId>commons-math3</artifactId>
            <version>3.6.1</version>
        </dependency>
    </dependencies>

</project>
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,118 @@ | ||
/* Copyright (C) 2018 Christian Römer | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
Contact: https://github.com/thunfischtoast or christian.roemer[at]udo.edu | ||
*/ | ||
|
||
package de.thunfischtoast; | ||
|
||
import org.apache.commons.math3.linear.ArrayRealVector; | ||
import org.apache.commons.math3.linear.RealVector; | ||
|
||
import java.util.Random; | ||
|
||
/** | ||
* Small test class as inspired by John Maxwell (http://john-maxwell.com/post/2017-03-17/). We create a context of two | ||
* features that represent reading preferences of imaginary news site visitors. The visitors like or dislike sites with | ||
* sports content (context[0] = 1 or 0) and like or dislike sites with politics content (context[1] = 1 or 0). The | ||
* features are bound to integers for easier analysis. The bandit should offer one of three sites to the visitor, each | ||
* having different contents to offer for sports and politics. | ||
* | ||
* The algorithm implementation has problems when all context features are 0, as the expected reward will always become 0 | ||
* as well. It is currently not apparent if this is a problem of the implementation of the algorithm itself. | ||
*/ | ||
public class BanditTest { | ||
|
||
public static void main(String[] args) { | ||
HybridLinUCB linUCB = new HybridLinUCB(2, 2, 3, 18); | ||
// LinUCB linUCB = new LinUCB(2, 3, 5); | ||
|
||
Random random = new Random(7); | ||
|
||
for (int j = 0; j < 2; j++) { | ||
for (int k = 0; k < 2; k++) { | ||
for (int i = 0; i < 3; i++) { | ||
System.out.println("Arm " + i + ", Context (" + j + ", " + k + ") mean is " + getMean(i, new double[]{j, k})); | ||
} | ||
} | ||
} | ||
|
||
double maxReward = 0; | ||
double minReward = 1000; | ||
int[][][] counters = new int[3][2][2]; | ||
|
||
for (int i = 0; i < 10000; i++) { | ||
double sports = random.nextInt(2); | ||
double politics = random.nextInt(2); | ||
|
||
ArrayRealVector context = new ArrayRealVector(new double[]{sports, politics}); | ||
if(linUCB instanceof HybridLinUCB) | ||
context = context.append(context); | ||
|
||
int arm = linUCB.chooseArm(context); | ||
|
||
// make sure that rewards are between 0 and 1 | ||
double reward = ((nextBoundedGaussian(random) + getMean(arm, context.toArray())) + 1 ) / 2.25; | ||
maxReward = Math.max(maxReward, reward); | ||
minReward = Math.min(minReward, reward); | ||
linUCB.receiveRewards(new RealVector[]{context}, new int[]{arm}, new double[]{reward}); | ||
|
||
counters[arm][(int) sports][(int) politics]++; | ||
} | ||
|
||
System.out.println("Max reward is " + maxReward + " min is " + minReward); | ||
|
||
for (int j = 0; j < 2; j++) { | ||
for (int k = 0; k < 2; k++) { | ||
System.out.print("Chosen arm counts for context (" + j + ", " + k + "): "); | ||
System.out.println(counters[0][j][k] + ", " + counters[1][j][k] + ", " + counters[2][j][k] + ", "); | ||
} | ||
} | ||
} | ||
|
||
private static double getMean(int arm, double[] context) { | ||
double sportsCoef; | ||
double politicsCoef; | ||
double armBaseline; | ||
|
||
if (arm == 0) { | ||
sportsCoef = 0.25; | ||
politicsCoef = 0.05; | ||
armBaseline = 0.025; | ||
} else if (arm == 1) { | ||
sportsCoef = 0.05; | ||
politicsCoef = 0.025; | ||
armBaseline = 0.05; | ||
} else { | ||
sportsCoef = 0.05; | ||
politicsCoef = 0.2; | ||
armBaseline = 0.075; | ||
} | ||
|
||
return armBaseline + context[0] * sportsCoef + context[1] * politicsCoef; | ||
} | ||
|
||
/** | ||
* Return a pseudorandom, Gaussian distributed double with mean 0 and standard deviation 1 bounded in [-1, 1] | ||
* @param random | ||
*/ | ||
private static double nextBoundedGaussian(Random random){ | ||
double v = random.nextGaussian(); | ||
v = Math.min(4, v); | ||
v = Math.max(-4, v); | ||
return v / 4; | ||
} | ||
|
||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
/* Copyright (C) 2018 Christian Römer | ||
This program is free software: you can redistribute it and/or modify | ||
it under the terms of the GNU General Public License as published by | ||
the Free Software Foundation, either version 3 of the License, or | ||
(at your option) any later version. | ||
This program is distributed in the hope that it will be useful, | ||
but WITHOUT ANY WARRANTY; without even the implied warranty of | ||
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the | ||
GNU General Public License for more details. | ||
You should have received a copy of the GNU General Public License | ||
along with this program. If not, see <https://www.gnu.org/licenses/>. | ||
Contact: https://github.com/thunfischtoast or christian.roemer[at]udo.edu | ||
*/ | ||
|
||
package de.thunfischtoast; | ||
|
||
import org.apache.commons.math3.linear.*; | ||
|
||
/** | ||
* This class implements a contextual bandit algorithm called LinUCB as proposed by Li, Langford and Schapire. | ||
* This is the version with hybrid linear models. | ||
* | ||
* @inproceedings{li2010contextual, | ||
* title={A contextual-bandit approach to personalized news article recommendation}, | ||
* author={Li, Lihong and Chu, Wei and Langford, John and Schapire, Robert E}, | ||
* booktitle={Proceedings of the 19th international conference on World wide web}, | ||
* pages={661--670}, | ||
* year={2010}, | ||
* organization={ACM} | ||
* } | ||
* | ||
* @author Christian Römer | ||
*/ | ||
public class HybridLinUCB extends LinUCB { | ||
|
||
/** Linear regression parameters for shared model */ | ||
private RealMatrix beta_hat; | ||
|
||
/** Context accumulators for shared model */ | ||
private RealMatrix[] B_a; | ||
private RealMatrix A_0; | ||
|
||
/** Reward accumulator for shared model */ | ||
private RealMatrix b_0; | ||
|
||
/** Number of shared features */ | ||
private int k; | ||
|
||
/** | ||
* @param d number of non-shared features | ||
* @param k number of shared features | ||
* @param n number of arms | ||
* @param alpha how many times the standard deviation of the expected payoff if added to the predicted payoff in the ridge regression | ||
*/ | ||
public HybridLinUCB(int d, int k, int n, double alpha) { | ||
super(d, n, alpha); | ||
|
||
if(k <= 0) | ||
throw new IllegalArgumentException("Number of features > 0. If there is are no shared features use @LinUCB"); | ||
|
||
this.k = k; | ||
|
||
A_0 = MatrixUtils.createRealIdentityMatrix(k); | ||
b_0 = new Array2DRowRealMatrix(k, 1); | ||
|
||
beta_hat = MatrixUtils.inverse(A_0).multiply(b_0); | ||
|
||
B_a = new RealMatrix[n]; | ||
for (int i = 0; i < n; i++) { | ||
B_a[i] = new Array2DRowRealMatrix(d, k); | ||
} | ||
} | ||
|
||
/** | ||
* Receive a reward for the given context and arm. Update the regression parameters accordingly. | ||
*/ | ||
public double[] getPayoffs(RealVector sharedContext, RealVector combinedContext){ | ||
return getPayoffs(sharedContext.append(combinedContext)); | ||
} | ||
|
||
/** | ||
* Receive a reward for the given context and arm. Update the regression parameters accordingly. | ||
* The given context must be of form [sharedContext,nonSharedContext]. | ||
*/ | ||
@Override | ||
public double[] getPayoffs(RealVector combinedContext) { | ||
if(combinedContext.getDimension() != k + d) | ||
throw new IllegalArgumentException("The given context must be of form [sharedContext,nonSharedContext]!"); | ||
|
||
RealVector sharedContext = combinedContext.getSubVector(0, k); | ||
RealVector nonSharedContext = combinedContext.getSubVector(k, d); | ||
|
||
double[] payoffs = new double[n]; | ||
|
||
RealMatrix x = new Array2DRowRealMatrix(nonSharedContext.toArray()); | ||
RealMatrix x_t = x.transpose(); | ||
|
||
RealMatrix z = new Array2DRowRealMatrix(sharedContext.toArray()); | ||
RealMatrix z_t = z.transpose(); | ||
|
||
RealMatrix A_0_inv = MatrixUtils.inverse(A_0); | ||
|
||
for(int i = 0; i < n; i++){ | ||
RealMatrix first = z_t.multiply(A_0_inv).multiply(z); | ||
RealMatrix second = z_t.multiply(A_a_inverse[i]).multiply(B_a[i].transpose()).multiply(A_a_inverse[i]).multiply(x).scalarMultiply(2); | ||
RealMatrix third = x_t.multiply(A_a_inverse[i]).multiply(x); | ||
RealMatrix fourth = x_t.multiply(A_a_inverse[i]).multiply(B_a[i]).multiply(A_0_inv).multiply(B_a[i].transpose()).multiply(A_a_inverse[i]).multiply(x); | ||
|
||
RealMatrix s = first.subtract(second).add(third).add(fourth); | ||
|
||
double firstElement = z_t.multiply(beta_hat).getEntry(0,0); | ||
double secondElement = x_t.multiply(theta_hat_a[i]).getEntry(0,0); | ||
|
||
if(firstElement != 0) | ||
payoffs[i] = firstElement + secondElement + (alpha * Math.sqrt(Math.abs(s.getEntry(0,0)))); | ||
else | ||
payoffs[i] = firstElement + secondElement; | ||
} | ||
|
||
return payoffs; | ||
} | ||
|
||
/** | ||
* Receive multiple rewards for the given contexts and arms. Update the regression parameters accordingly. | ||
* The given contexts must be of form [sharedContext,nonSharedContext]. | ||
*/ | ||
@Override | ||
public void receiveRewards(RealVector[] combinedContexts, int[] arm, double[] reward) { | ||
for (int i = 0; i < combinedContexts.length; i++) { | ||
RealVector combinedContext = combinedContexts[i]; | ||
if(combinedContext.getDimension() != k + d) | ||
throw new IllegalArgumentException("The given context must be of form [sharedContext,nonSharedContext]!"); | ||
|
||
RealVector sharedContext = combinedContext.getSubVector(0, k); | ||
RealVector nonSharedContext = combinedContext.getSubVector(k, d); | ||
|
||
RealMatrix sharedContextMatrix = new Array2DRowRealMatrix(sharedContext.toArray()); | ||
RealMatrix sharedContextMatrixTranspose = sharedContextMatrix.transpose(); | ||
RealMatrix nonSharedContextMatrix = new Array2DRowRealMatrix(nonSharedContext.toArray()); | ||
RealMatrix nonSharedContextMatrixTranspose = nonSharedContextMatrix.transpose(); | ||
|
||
RealMatrix zMultz_t = sharedContextMatrix.multiply(sharedContextMatrixTranspose); | ||
RealMatrix xMultx_t = nonSharedContextMatrix.multiply(nonSharedContextMatrixTranspose); | ||
RealMatrix xMultz_t = nonSharedContextMatrix.multiply(sharedContextMatrixTranspose); | ||
|
||
A_0 = A_0.add(B_a[arm[i]].transpose().multiply(A_a_inverse[arm[i]].transpose()).multiply(B_a[arm[i]])); | ||
|
||
b_0 = b_0.add(B_a[arm[i]].transpose().multiply(A_a_inverse[arm[i]].transpose()).multiply(b_a[arm[i]].transpose())); | ||
|
||
A_a[arm[i]] = A_a[arm[i]].add(xMultx_t); // update A[arm] by adding x_t[arm]*x_t[arm]^transposed to it | ||
B_a[arm[i]] = B_a[arm[i]].add(xMultz_t); | ||
|
||
double[] rMultx = nonSharedContext.mapMultiply(reward[i]).toArray(); | ||
double[] rMultz = sharedContext.mapMultiply(reward[i]).toArray(); | ||
b_a[arm[i]] = b_a[arm[i]].add(new Array2DRowRealMatrix(rMultx).transpose()); // update b[arm] by adding r_t * x_t[arm] to it | ||
|
||
A_0 = A_0.add(zMultz_t).subtract(B_a[arm[i]].transpose().multiply(A_a_inverse[arm[i]].multiply(B_a[arm[i]]))); | ||
b_0 = b_0.add(new Array2DRowRealMatrix(rMultz)).subtract(B_a[arm[i]].transpose().multiply(A_a_inverse[arm[i]].multiply(b_a[arm[i]].transpose()))); | ||
|
||
for (int j = 0; j < A_a.length; j++) { | ||
A_a_inverse[j] = MatrixUtils.inverse(A_a[j]); | ||
theta_hat_a[j] = A_a_inverse[j].multiply(b_a[j].transpose()); | ||
} | ||
|
||
beta_hat = MatrixUtils.inverse(A_0).multiply(b_0); | ||
} | ||
} | ||
} |
Oops, something went wrong.