From eba946cab4d3d745fe448610802fb6538c51a98a Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?Adam=20Gudy=C5=9B?=
Date: Tue, 19 Dec 2023 12:26:14 +0100
Subject: [PATCH] Multiple speed improvements

* Classification:
  * Significantly faster growing (two orders of magnitude for sets with >100k instances), faster pruning.
  * Added approximate mode (`approximate_induction` parameter).
* Regression:
  * Mean-based growing is now the default (a few times faster than the median-based variant, with no significant impact on accuracy).
* Survival:
  * Faster growing and pruning (a few-fold improvement).
---
 adaa.analytics.rules/build.gradle                  |   2 +-
 .../rules/consoles/ExperimentalConsole.java        |   2 +-
 .../rules/logic/induction/AbstractFinder.java      |  60 ++-
 .../rules/logic/induction/ActionFinder.java        |   4 +-
 .../ApproximateClassificationFinder.java           |  96 +++--
 .../induction/ClassificationExpertFinder.java      |  10 +-
 .../induction/ClassificationExpertSnC.java         |  10 +-
 .../logic/induction/ClassificationFinder.java      | 108 +++---
 .../logic/induction/ClassificationSnC.java         |   8 +-
 .../induction/ContrastRegressionFinder.java        |   1 +
 .../logic/induction/InductionParameters.java       |   9 +-
 .../induction/RegressionExpertFinder.java          |  68 +++-
 .../logic/induction/RegressionExpertSnC.java       |  43 ++-
 .../logic/induction/RegressionFinder.java          |  25 +-
 .../rules/logic/induction/RegressionSnC.java       |  30 +-
 .../SurvivalLogRankExpertFinder.java               |  34 +-
 .../induction/SurvivalLogRankFinder.java           | 126 +++++-
 .../rules/logic/quality/LogRank.java               |  94 +++--
 .../logic/representation/ConditionBase.java        |   9 +
 .../logic/representation/IntegerBitSet.java        |  37 +-
 .../representation/KaplanMeierEstimator.java       | 360 ++++++++----------
 .../logic/representation/RegressionRule.java       |   1 +
 .../representation/SortedExampleSetEx.java         |  10 +
 .../rules/operator/RuleGenerator.java              |   6 +
 .../config/RegressionExpertSnCTest.xml             |   4 +
 .../resources/config/RegressionSnCTest.xml         |   1 +
 .../testcases/InductionParametersFactory.java      |   3 +
 27 files changed, 744 insertions(+), 417 deletions(-)

diff --git a/adaa.analytics.rules/build.gradle b/adaa.analytics.rules/build.gradle
index 6e93e19f..ed6b3986 100644
--- a/adaa.analytics.rules/build.gradle
+++ b/adaa.analytics.rules/build.gradle
@@ -27,7 +27,7 @@ codeQuality {
 }
 
 sourceCompatibility = 1.8
-version = '1.6.2'
+version = '1.7.0'
 
 
 jar {
diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/consoles/ExperimentalConsole.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/consoles/ExperimentalConsole.java
index fd63c4a3..640fc93f 100644
--- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/consoles/ExperimentalConsole.java
+++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/consoles/ExperimentalConsole.java
@@ -150,7 +150,7 @@ private void parse(String[] args) {
 
 			RapidMiner.setExecutionMode(RapidMiner.ExecutionMode.COMMAND_LINE);
 			RapidMiner.init();
-			//System.in.read();
+			// System.in.read();
 
 			execute(argList.get(0));
 		} else {
diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/AbstractFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/AbstractFinder.java
index 8b5260df..e7c43076 100644
--- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/AbstractFinder.java
+++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/AbstractFinder.java
@@ -81,8 +81,9 @@ public void close() {
 	/**
 	 * Can be implemented by subclasses to perform some initial processing prior growing.
 	 * @param trainSet Training set.
+ * @return Preprocessed training set. */ - public void preprocess(ExampleSet trainSet) {} + public ExampleSet preprocess(ExampleSet trainSet) { return trainSet; } /** * Adds elementary conditions to the rule premise until termination conditions are fulfilled. @@ -104,11 +105,14 @@ public int grow( int initialConditionsCount = rule.getPremise().getSubconditions().size(); // get current covering - Covering covering = new Covering(); - rule.covers(dataset, covering, covering.positives, covering.negatives); - Set covered = new HashSet(); - covered.addAll(covering.positives); - covered.addAll(covering.negatives); + ContingencyTable contingencyTable = new Covering(); + IntegerBitSet positives = new IntegerBitSet(dataset.size()); + IntegerBitSet negatives = new IntegerBitSet(dataset.size()); + rule.covers(dataset, contingencyTable, positives, negatives); + //Set covered = new HashSet(); + IntegerBitSet covered = new IntegerBitSet(dataset.size()); + covered.addAll(positives); + covered.addAll(negatives); Set allowedAttributes = new TreeSet(new AttributeComparator()); for (Attribute a: dataset.getAttributes()) { allowedAttributes.add(a); @@ -126,18 +130,23 @@ public int grow( notifyConditionAdded(condition); - covering = new Covering(); - rule.covers(dataset, covering, covering.positives, covering.negatives); - covered.clear(); - covered.addAll(covering.positives); - covered.addAll(covering.negatives); + //recalculate covering only when needed + if (condition.getCovering() != null) { + positives.retainAll(condition.getCovering()); + negatives.retainAll(condition.getCovering()); + covered.retainAll(condition.getCovering()); + } else { + contingencyTable.clear(); + positives.clear(); + negatives.clear(); + + rule.covers(dataset, contingencyTable, positives, negatives); + covered.clear(); + covered.addAll(positives); + covered.addAll(negatives); + } - rule.setCoveringInformation(covering); - rule.getCoveredPositives().setAll(covering.positives); - rule.getCoveredNegatives().setAll(covering.negatives); - rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); - Logger.log("Condition " + rule.getPremise().getSubconditions().size() + " added: " + rule.toString() + ", weight=" + rule.getWeight() + "\n", Level.FINER); @@ -152,12 +161,25 @@ public int grow( carryOn = false; } - } while (carryOn); - + } while (carryOn); + + // ugly + Covering covering = new Covering(); + covering.positives = positives; + covering.negatives = negatives; + + rule.setCoveringInformation(covering); + rule.getCoveredPositives().setAll(positives); + rule.getCoveredNegatives().setAll(negatives); + // if rule has been successfully grown int addedConditionsCount = rule.getPremise().getSubconditions().size() - initialConditionsCount; - rule.setInducedContitionsCount(addedConditionsCount); + if (addedConditionsCount > 0) { + rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); + } + + rule.setInducedContitionsCount(addedConditionsCount); notifyGrowingFinished(rule); return addedConditionsCount; diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ActionFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ActionFinder.java index 861d8de7..45418c94 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ActionFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ActionFinder.java @@ -30,8 +30,8 @@ public ActionFinder(ActionInductionParameters params) { 
classificationFinder = new ClassificationFinder(params); } - public void preprocess(ExampleSet trainSet) { - classificationFinder.preprocess(trainSet); + public ExampleSet preprocess(ExampleSet trainSet) { + return classificationFinder.preprocess(trainSet); } private void log(String msg, Level level) { diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ApproximateClassificationFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ApproximateClassificationFinder.java index af211aa6..a2a3e81f 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ApproximateClassificationFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ApproximateClassificationFinder.java @@ -32,8 +32,6 @@ public ConditionCandidate(String attribute, IValueSet valueSet) { } } - protected static final int MAX_BINS = 100; - // Example description: // [0-31] - example id (32 bits) // [32-47] - block id (16 bits) @@ -73,7 +71,7 @@ public ApproximateClassificationFinder(InductionParameters params) { } @Override - public void preprocess(ExampleSet dataset) { + public ExampleSet preprocess(ExampleSet dataset) { int n_examples = dataset.size(); int n_attributes = dataset.getAttributes().size(); @@ -81,19 +79,28 @@ public void preprocess(ExampleSet dataset) { descriptions = new long[n_attributes][n_examples]; mappings = new int[n_attributes][n_examples]; - bins_positives = new int[n_attributes][MAX_BINS]; - bins_negatives = new int[n_attributes][MAX_BINS]; - bins_newPositives = new int[n_attributes][MAX_BINS]; - bins_begins = new int[n_attributes][MAX_BINS]; + bins_positives = new int[n_attributes][]; + bins_negatives = new int[n_attributes][]; + bins_newPositives = new int[n_attributes][]; + bins_begins = new int[n_attributes][]; ruleRanges = new int[n_attributes][2]; for (Attribute attr: dataset.getAttributes()) { int ia = attr.getTableIndex(); + int n_vals = attr.isNominal() ? 
attr.getMapping().size() : params.getApproximateBinsCount(); + + bins_positives[ia] = new int [n_vals]; + bins_negatives[ia] = new int[n_vals]; + bins_newPositives[ia] = new int[n_vals]; + bins_begins[ia] = new int[n_vals]; determineBins(dataset, attr, descriptions[ia], mappings[ia], bins_begins[ia], ruleRanges[ia]); + arrayCopies.put("ruleRanges", (Object)Arrays.stream(ruleRanges).map(int[]::clone).toArray(int[][]::new)); } + + return dataset; } /** @@ -293,13 +300,14 @@ protected ElementaryCondition induceCondition( int covered_n = 0; int covered_new_p = 0; - // use first attribute to establish number of covered elements + // use first attribute to establish number of covered elements for (int bid = ruleRanges[0][0]; bid < ruleRanges[0][1]; ++bid) { covered_p += bins_positives[0][bid]; covered_n += bins_negatives[0][bid]; covered_new_p += bins_newPositives[0][bid]; } + // iterate over all allowed decision attributes for (Attribute attr : dataset.getAttributes()) { @@ -462,7 +470,10 @@ class Stats { if (current != null && current.getAttribute() != null) { Logger.log("\tAttribute best: " + current + ", quality=" + current.quality, Level.FINEST); - updateMidpoint(dataset, current); + Attribute attr = dataset.getAttributes().get(current.getAttribute()); + if (attr.isNumerical()) { + updateMidpoint(dataset, current); + } Logger.log(", adjusted: " + current + "\n", Level.FINEST); } @@ -482,13 +493,13 @@ class Stats { return null; // empty condition - discard } - updateMidpoint(dataset, best); - - Logger.log("\tFinal best: " + best + ", quality=" + best.quality + "\n", Level.FINEST); - - if (bestAttr.isNominal()) { + if (bestAttr.isNumerical()) { + updateMidpoint(dataset, best); + } else { allowedAttributes.remove(bestAttr); } + + Logger.log("\tFinal best: " + best + ", quality=" + best.quality + "\n", Level.FINEST); } return best; @@ -508,7 +519,7 @@ protected void notifyConditionAdded(ConditionBase cnd) { ruleRanges[aid][0] = blockId + 1; ruleRanges[aid][1] = blockId; } else { - excludeExamplesFromArrays(trainSet, attr, ruleRanges[aid][0], candidate.blockId + 1); + excludeExamplesFromArrays(trainSet, attr, ruleRanges[aid][0], candidate.blockId); excludeExamplesFromArrays(trainSet, attr, candidate.blockId + 1, ruleRanges[aid][1]); ruleRanges[aid][0] = blockId; ruleRanges[aid][1] = blockId + 1; @@ -546,6 +557,7 @@ protected void determineBins(ExampleSet dataset, Attribute attr, vals[i] = dataset.getExample(i).getValue(attr); } + /* class ValuesComparator implements IntComparator { double [] vals; @@ -597,12 +609,12 @@ public int compare(Bin p, Bin q) { } } - PriorityQueue bins = new PriorityQueue(100, new SizeBinComparator()); - PriorityQueue finalBins = new PriorityQueue(100, new IndexBinComparator()); + PriorityQueue bins = new PriorityQueue(binsBegins.length, new SizeBinComparator()); + PriorityQueue finalBins = new PriorityQueue(binsBegins.length, new IndexBinComparator()); bins.add(new Bin(0, mappings.length)); - while (bins.size() > 0 && (bins.size() + finalBins.size()) < MAX_BINS) { + while (bins.size() > 0 && (bins.size() + finalBins.size()) < binsBegins.length) { Bin b = bins.poll(); int id = (b.end + b.begin) / 2; @@ -611,9 +623,13 @@ public int compare(Bin p, Bin q) { // decide direction if (vals[b.begin] == midval) { // go up - while (vals[id] == midval) { ++id; } + while (vals[id] == midval) { + ++id; + } } else { - while (vals[id - 1] == midval) { --id; } + while (vals[id - 1] == midval) { + --id; + } } Bin leftBin = new Bin(b.begin, id); @@ -646,17 +662,16 @@ public int 
compare(Bin p, Bin q) { descriptions[i] |= bid << OFFSET_BIN; } - binsBegins[(int)bid] = b.begin; + binsBegins[(int) bid] = b.begin; ++bid; } ruleRanges[0] = 0; - ruleRanges[1] = (int)bid; - - // print bins - for (int i = 0; i < bid; ++i) { + ruleRanges[1] = (int) bid; + // print bins + for (int i = 0; i < ruleRanges[1]; ++i) { int lo = binsBegins[i]; - int hi = (i == bid - 1) ? trainSet.size() : binsBegins[i+1] - 1; + int hi = (i == ruleRanges[1] - 1) ? trainSet.size() : binsBegins[i+1] - 1; Logger.log("[" + lo + ", " + hi + "]:" + vals[lo] + "\n", Level.FINER); } } @@ -665,6 +680,10 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int Logger.log("Excluding examples: " + attr.getName() + " from [" + binLo + "," + binHi + "]\n", Level.FINER); + if (binLo == binHi) { + return; + } + int n_examples = dataset.size(); int src_row = attr.getTableIndex(); long[] src_descriptions = descriptions[src_row]; @@ -695,9 +714,11 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int int dst_row = other.getTableIndex(); // if nominal attribute was already used + /* if (other.isNominal() && Math.abs(ruleRanges[dst_row][1] - ruleRanges[dst_row][0]) == 1) { continue; } + */ Future future = pool.submit(() -> { @@ -717,8 +738,14 @@ protected void excludeExamplesFromArrays(ExampleSet dataset, Attribute attr, int int bid = (int) ((desc & MASK_BIN) >> OFFSET_BIN); + boolean opposite = dst_ranges[0] > dst_ranges[1]; // this indicate nominal opposite condition + int dst_bin_lo = Math.min(dst_ranges[0], dst_ranges[1]); + int dst_bin_hi = Math.max(dst_ranges[0], dst_ranges[1]); + // update stats only in bins covered by the rule - if (bid >= dst_ranges[0] && bid < dst_ranges[1] && ((desc & FLAG_COVERED) != 0)) { + boolean in_range = (bid >= dst_bin_lo && bid < dst_bin_hi) || (opposite && (bid < dst_bin_lo || bid >= dst_bin_hi)); + + if (in_range && ((desc & FLAG_COVERED) != 0)) { if ((desc & FLAG_POSITIVE) != 0) { --dst_positives[bid]; @@ -755,12 +782,16 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) { int n_examples = dataset.size(); + int[][] copy_ranges = (int[][])arrayCopies.get("ruleRanges"); + for (Attribute attr: dataset.getAttributes()) { int attribute_id = attr.getTableIndex(); Arrays.fill(bins_positives[attribute_id], 0); Arrays.fill(bins_negatives[attribute_id], 0); Arrays.fill(bins_newPositives[attribute_id], 0); + ruleRanges[attribute_id][0] = 0; + ruleRanges[attribute_id][1] = copy_ranges[attribute_id][1]; long[] descriptions_row = descriptions[attribute_id]; int[] mappings_row = mappings[attribute_id]; @@ -792,6 +823,9 @@ protected void resetArrays(ExampleSet dataset, int targetLabel) { } } + // reset rule ranges + + Logger.log("Reset arrays for class " + targetLabel + "\n", Level.FINER); printArrays(); @@ -816,9 +850,13 @@ protected void printArrays() { int bin_p = 0, bin_n = 0, bin_new_p = 0, bin_outside = 0; - for (int i = 0; i < MAX_BINS; ++i) { + boolean opposite = ruleRanges[attribute_id][0] > ruleRanges[attribute_id][1]; // this indicate nominal opposite condition + int lo = Math.min(ruleRanges[attribute_id][0], ruleRanges[attribute_id][1]); + int hi = Math.max(ruleRanges[attribute_id][0], ruleRanges[attribute_id][1]); + + for (int i = 0; i < bins_positives[attribute_id].length; ++i) { - if (i >= ruleRanges[attribute_id][0] && i < ruleRanges[attribute_id][1]) { + if ((i >= lo && i < hi) || (opposite && (i < lo || i >= hi)) ) { bin_p += bins_positives[attribute_id][i]; bin_n += bins_negatives[attribute_id][i]; 
bin_new_p += bins_newPositives[attribute_id][i]; diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertFinder.java index d3108be4..d8d1752e 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertFinder.java @@ -111,7 +111,7 @@ public void adjust( newCondition = induceCondition( rule, dataset, mustBeCovered, covered, attr); newCondition.setType(Type.FORCED); - tryAddCondition(rule, null, newCondition, dataset, covered, uncoveredPositives, conditionCovered); + tryAddCondition(rule, null, newCondition, dataset, covered, uncoveredPositives); } else { // add condition as it is without verification @@ -265,7 +265,7 @@ public int grow( } if (bestCondition != null) { - carryOn = tryAddCondition(rule, null, bestCondition, dataset, covered,uncoveredPositives, conditionCovered); + carryOn = tryAddCondition(rule, null, bestCondition, dataset, covered, uncoveredPositives); knowledge.getPreferredConditions((int)classId).remove(bestCondition); newlyCoveredPositives.retainAll(rule.getCoveredPositives()); @@ -298,7 +298,7 @@ public int grow( do { ElementaryCondition condition = induceCondition(rule, dataset, uncoveredPositives, covered, localAllowed, rule.getCoveredPositives()); - carryOn = tryAddCondition(rule, null, condition, dataset, covered,uncoveredPositives, conditionCovered); + carryOn = tryAddCondition(rule, null, condition, dataset, covered,uncoveredPositives); // fixme: we are not sure if condition was added if (carryOn) { knowledge.getPreferredAttributes((int)classId).remove(condition.getAttribute()); @@ -335,9 +335,9 @@ public int grow( rule, dataset, uncoveredPositives, covered, allowedAttributes, rule.getCoveredPositives()); if (params.getSelectBestCandidate()) { - carryOn = tryAddCondition(currentRule, rule, condition, dataset, covered, uncoveredPositives, conditionCovered); + carryOn = tryAddCondition(currentRule, rule, condition, dataset, covered, uncoveredPositives); } else { - carryOn = tryAddCondition(rule, null, condition, dataset, covered, uncoveredPositives, conditionCovered); + carryOn = tryAddCondition(rule, null, condition, dataset, covered, uncoveredPositives); } } while (carryOn); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertSnC.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertSnC.java index fb45fac4..5826a1cd 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertSnC.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationExpertSnC.java @@ -78,9 +78,9 @@ public ClassificationRuleSet run(ExampleSet dataset) // iterate over all classes for (int classId = 0; classId < mapping.size(); ++classId) { - Set positives = new IntegerBitSet(dataset.size()); - Set negatives = new IntegerBitSet(dataset.size()); - Set uncoveredPositives = new IntegerBitSet(dataset.size()); + IntegerBitSet positives = new IntegerBitSet(dataset.size()); + IntegerBitSet negatives = new IntegerBitSet(dataset.size()); + IntegerBitSet uncoveredPositives = new IntegerBitSet(dataset.size()); Set uncovered = new HashSet(); double weighted_P = 0; double weighted_N = 0; @@ -145,6 +145,8 @@ public 
ClassificationRuleSet run(ExampleSet dataset) rule.setCoveredNegatives(new IntegerBitSet(dataset.size())); rule.getCoveredPositives().addAll(positives); rule.getCoveredNegatives().addAll(negatives); + + rule.getConsequence().setCovering(positives); ClassificationExpertFinder erf = (ClassificationExpertFinder)finder; @@ -194,6 +196,8 @@ public ClassificationRuleSet run(ExampleSet dataset) rule.setCoveredNegatives(new IntegerBitSet(dataset.size())); rule.getCoveredPositives().addAll(positives); rule.getCoveredNegatives().addAll(negatives); + + rule.getConsequence().setCovering(positives); ClassificationExpertFinder erf = (ClassificationExpertFinder)finder; erf.setKnowledge(classKnowledge); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationFinder.java index e578bdd7..453fcf19 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationFinder.java @@ -14,15 +14,11 @@ ******************************************************************************/ package adaa.analytics.rules.logic.induction; -import java.io.IOException; import java.util.*; import java.util.concurrent.ExecutionException; import java.util.concurrent.Future; import java.util.logging.Level; -import adaa.analytics.rules.logic.quality.ClassificationMeasure; -import adaa.analytics.rules.logic.quality.Hypergeometric; -import adaa.analytics.rules.logic.quality.IQualityMeasure; import adaa.analytics.rules.logic.representation.*; import com.rapidminer.example.Attribute; @@ -30,7 +26,6 @@ import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; import com.rapidminer.example.table.DataRow; -import com.rapidminer.tools.container.Pair; /** @@ -64,11 +59,11 @@ public ClassificationFinder(InductionParameters params) { * @param trainSet Training set. 
*/ @Override - public void preprocess(ExampleSet trainSet) { + public ExampleSet preprocess(ExampleSet trainSet) { // do nothing for weighted datasets if (trainSet.getAttributes().getWeight() != null) { - return; + return trainSet; } precalculatedCoverings = new HashMap>(); @@ -126,6 +121,8 @@ public void preprocess(ExampleSet trainSet) { } catch (InterruptedException | ExecutionException e) { e.printStackTrace(); } + + return trainSet; } /** @@ -152,9 +149,6 @@ public int grow( covered.addAll(rule.getCoveredPositives()); covered.addAll(rule.getCoveredNegatives()); - // bit vectors for faster operations on coverings - IntegerBitSet conditionCovered = new IntegerBitSet(dataset.size()); - Set allowedAttributes = new TreeSet(new AttributeComparator()); for (Attribute a: dataset.getAttributes()) { allowedAttributes.add(a); @@ -180,7 +174,7 @@ public int grow( if (condition != null) { - carryOn = tryAddCondition(currentRule, bestRule, condition, dataset, covered, uncovered, conditionCovered); + carryOn = tryAddCondition(currentRule, bestRule, condition, dataset, covered, uncovered); if (params.getMaxGrowingConditions() > 0) { if (currentRule.getPremise().getSubconditions().size() - initialConditionsCount >= @@ -230,9 +224,9 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set int examplesCount = trainSet.size(); int conditionsCount = rule.getPremise().getSubconditions().size(); - int maskLength = (trainSet.size() + Long.SIZE - 1) / Long.SIZE; - long[] masks = new long[conditionsCount * maskLength]; - long[] labelMask = new long[maskLength]; + int maskLength = (trainSet.size() + Long.SIZE - 1) / Long.SIZE; + // long[] masks = new long[conditionsCount * maskLength]; + long[] labelMask = rule.getConsequence().getCovering().getRawTable(); long[] uncoveredMask = new long[maskLength]; double P = rule.getWeighted_P(); @@ -242,14 +236,28 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set int counter_p = 0; int counter_new_p = 0; + // copy coverings from conditions into masks array + for (int m = 0; m < conditionsCount; ++m) { + ConditionBase cnd = rule.getPremise().getSubconditions().get(m); + long[] conditionWords = cnd.getCovering().getRawTable(); + + // count conditions + for (int i = 0; i < trainSet.size(); ++i) { + int wordId = i >>> IntegerBitSet.ID_SHIFT; + int wordOffset = i & IntegerBitSet.OFFSET_MASK; + + if ((conditionWords[wordId] & (1L << wordOffset)) != 0) { + ++conditionsPerExample[i]; + } + } + } + for (int i = 0; i < trainSet.size(); ++i) { - Example e = trainSet.getExample(i); - int wordId = i / Long.SIZE; - int wordOffset = i % Long.SIZE; + int wordId = i >>> IntegerBitSet.ID_SHIFT; + int wordOffset = i & IntegerBitSet.OFFSET_MASK; // is positive - if (rule.getConsequence().evaluate(e)) { - labelMask[wordId] |= 1L << wordOffset; + if ((labelMask[wordId] & (1L << wordOffset)) != 0) { ++counter_p; // is uncovered @@ -258,23 +266,14 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set ++counter_new_p; } } - - for (int m = 0; m < conditionsCount; ++m) { - ConditionBase cnd = rule.getPremise().getSubconditions().get(m); - if (cnd.evaluate(e)) { - masks[m * maskLength + wordId] |= 1L << wordOffset; - ++conditionsPerExample[i]; - } - } } - IntegerBitSet removedConditions = new IntegerBitSet(conditionsCount); int conditionsLeft = rule.getPremise().getSubconditions().size(); - ContingencyTable ct = new ContingencyTable(); - rule.covers(trainSet, ct); - double initialQuality = params.getPruningMeasure().calculate(trainSet, 
ct); + //ContingencyTable ct = new ContingencyTable(); + //rule.covers(trainSet, ct); + double initialQuality = params.getPruningMeasure().calculate(trainSet, rule.getCoveringInformation()); initialQuality = modifier.modifyQuality(initialQuality, null, counter_p, counter_new_p); boolean continueClimbing = true; @@ -292,7 +291,7 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set for (int cid = 0; cid < conditionsCount; ++cid) { final int fcid = cid; Future f = pool.submit( () -> { - + ConditionBase cnd = (ConditionBase) rule.getPremise().getSubconditions().get(fcid); // ignore already removed conditions if (removedConditions.contains(fcid)) { @@ -306,7 +305,8 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set // only elementary conditions are prunable String attr = ((ElementaryCondition)cnd).getAttribute(); - + + long[] conditionWords = cnd.getCovering().getRawTable(); double p = 0; double n = 0; double new_p = 0; @@ -315,7 +315,7 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set int id = 0; for (int wordId = 0; wordId < maskLength; ++wordId) { - long word = masks[fcid * maskLength + wordId]; + long word = conditionWords[wordId]; long filteredWord = 0; for (int wordOffset = 0; wordOffset < Long.SIZE && id < examplesCount; ++wordOffset, ++id) { @@ -381,14 +381,16 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set initialQuality = bestQuality; removedConditions.add(toRemove); - notifyConditionRemoved(rule.getPremise().getSubconditions().get(toRemove)); + ConditionBase cndToRemove = rule.getPremise().getSubconditions().get(toRemove); + notifyConditionRemoved(cndToRemove); --conditionsLeft; // decrease counters for examples covered by removed condition int id = 0; + long[] localWords = cndToRemove.getCovering().getRawTable(); for (int wordId = 0; wordId < maskLength; ++wordId) { - long word = masks[toRemove * maskLength + wordId]; + long word = localWords[wordId]; for (int wordOffset = 0; wordOffset < Long.SIZE && id < examplesCount; ++wordOffset, ++id) { if ((word & (1L << wordOffset)) != 0) { @@ -406,23 +408,35 @@ public void prune(final Rule rule, final ExampleSet trainSet, final Set } CompoundCondition prunedPremise = new CompoundCondition(); + + IntegerBitSet positives = rule.getConsequence().getCovering().clone(); + IntegerBitSet negatives = new IntegerBitSet(trainSet.size()); + positives.negate(negatives); + + long[] positiveWords = positives.getRawTable(); + long[] negativeWords = negatives.getRawTable(); for (int cid = 0; cid < conditionsCount; ++cid) { if (!removedConditions.contains(cid)) { - prunedPremise.addSubcondition(rule.getPremise().getSubconditions().get(cid)); + ConditionBase cnd = rule.getPremise().getSubconditions().get(cid); + prunedPremise.addSubcondition(cnd); + + long [] words = cnd.getCovering().getRawTable(); + + for (int wordId = 0; wordId < maskLength; ++wordId) { + positiveWords[wordId] &= words[wordId]; + negativeWords[wordId] &= words[wordId]; + } } } rule.setPremise(prunedPremise); - ct = new ContingencyTable(); - IntegerBitSet positives = new IntegerBitSet(trainSet.size()); - IntegerBitSet negatives = new IntegerBitSet(trainSet.size()); - - rule.covers(trainSet, ct, positives, negatives); + //ct = new ContingencyTable(); + //rule.covers(trainSet, ct, positives, negatives); - rule.setWeighted_p(ct.weighted_p); - rule.setWeighted_n(ct.weighted_n); + rule.setWeighted_p(positives.size()); + rule.setWeighted_n(negatives.size()); rule.setCoveredPositives(positives); 
rule.setCoveredNegatives(negatives); @@ -768,16 +782,16 @@ public boolean tryAddCondition( final ConditionBase condition, final ExampleSet trainSet, final Set covered, - final Set uncovered, - final IntegerBitSet conditionCovered) { + final Set uncovered) { boolean carryOn = true; boolean add = false; ContingencyTable ct = new ContingencyTable(); if (condition != null) { - conditionCovered.clear(); + IntegerBitSet conditionCovered = new IntegerBitSet(trainSet.size()); condition.evaluate(trainSet, conditionCovered); + condition.setCovering(conditionCovered); // calculate quality before addition ct.weighted_P = currentRule.getWeighted_P(); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationSnC.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationSnC.java index cd6107c3..cd3783c3 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationSnC.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ClassificationSnC.java @@ -93,9 +93,9 @@ public RuleSetBase run(ExampleSet dataset) { ClassificationRuleSet ruleset = (ClassificationRuleSet) factory.create(dataset); - Set positives = new IntegerBitSet(dataset.size()); - Set negatives = new IntegerBitSet(dataset.size()); - Set uncoveredPositives = new IntegerBitSet(dataset.size()); + IntegerBitSet positives = new IntegerBitSet(dataset.size()); + IntegerBitSet negatives = new IntegerBitSet(dataset.size()); + IntegerBitSet uncoveredPositives = new IntegerBitSet(dataset.size()); Set uncovered = new HashSet(); double weighted_P = 0; @@ -141,6 +141,8 @@ public RuleSetBase run(ExampleSet dataset) { rule.getCoveredNegatives().addAll(negatives); rule.setRuleOrderNum(ruleset.getRules().size()); + rule.getConsequence().setCovering(positives); + double t = System.nanoTime(); carryOn = (finder.grow(rule, dataset, uncoveredPositives) > 0); ruleset.setGrowingTime( ruleset.getGrowingTime() + (System.nanoTime() - t) / 1e9); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastRegressionFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastRegressionFinder.java index 25188a63..1d33cce2 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastRegressionFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/ContrastRegressionFinder.java @@ -75,6 +75,7 @@ public ContrastRegressionFinder(InductionParameters params) { params.setInductionMeasure(m); params.setPruningMeasure(new NegativeControlledMeasure(m, params.getMaxcovNegative())); params.setVotingMeasure(m); + params.setMeanBasedRegression(false); } /** diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/InductionParameters.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/InductionParameters.java index 2015850a..c09c93e1 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/InductionParameters.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/InductionParameters.java @@ -60,10 +60,11 @@ public class InductionParameters implements Serializable { private List minimumCoveredAll_list = new ArrayList(); private int maxPassesCount = 1; - private boolean meanBasedRegression = false; + private boolean meanBasedRegression = true; private boolean controlAprioriPrecision = true; private boolean 
approximateInduction = false; + private int approximateBinsCount = 100; public IQualityMeasure getInductionMeasure() {return inductionMeasure;} public void setInductionMeasure(IQualityMeasure inductionMeasure) {this.inductionMeasure = inductionMeasure;} @@ -134,6 +135,9 @@ public void setMaxRuleCount(int maxRuleCount) { public boolean isApproximateInduction() { return approximateInduction; } public void setApproximateInduction(boolean v) { approximateInduction = v; } + public int getApproximateBinsCount() { return approximateBinsCount; } + public void setApproximateBinsCount(int v) { approximateBinsCount = v; } + public List getMinimumCoveredAll_list() { return minimumCoveredAll_list; } public void setMinimumCoveredAll_list(List minimumCovered) {this.minimumCoveredAll_list.addAll(minimumCovered);} @@ -161,7 +165,8 @@ public String toString() { "select_best_candidate=" + selectBestCandidate + "\n" + "max_passes_count=" + maxPassesCount + "\n" + "complementary_conditions=" + conditionComplementEnabled + "\n" + - "approximate_induction=" + approximateInduction + "\n"; + "approximate_induction=" + approximateInduction + "\n" + + "approximate_bins_count=" + (approximateInduction ? approximateBinsCount : "OFF") + "\n"; } diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertFinder.java index 86a10b2a..66c5aa25 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertFinder.java @@ -102,12 +102,18 @@ public void adjust( } covering.clear(); - rule.covers(dataset, covering, covering.positives, covering.negatives); + IntegerBitSet positives = new IntegerBitSet(dataset.size()); + IntegerBitSet negatives = new IntegerBitSet(dataset.size()); + rule.covers(dataset, covering, positives, negatives); + + // ugly + covering.positives = positives; + covering.negatives = negatives; rule.setCoveringInformation(covering); - rule.getCoveredPositives().setAll(covering.positives); - rule.getCoveredNegatives().setAll(covering.negatives); - + rule.setCoveredNegatives(negatives); + rule.setCoveredPositives(positives); + rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); } @@ -124,13 +130,14 @@ public int grow( boolean isRuleEmpty = rule.getPremise().getSubconditions().size() == 0; // get current covering - + IntegerBitSet positives = new IntegerBitSet(dataset.size()); + IntegerBitSet negatives = new IntegerBitSet(dataset.size()); Covering covering = new Covering(); - rule.covers(dataset, covering, covering.positives, covering.negatives); + rule.covers(dataset, covering, positives, negatives); - Set covered = new HashSet(); - covered.addAll(covering.positives); - covered.addAll(covering.negatives); + IntegerBitSet covered = new IntegerBitSet(dataset.size()); + covered.addAll(positives); + covered.addAll(negatives); Set allowedAttributes = new TreeSet(new AttributeComparator()); // add all attributes @@ -166,7 +173,7 @@ public int grow( continue; } - checkCandidate(dataset, rule, candidate, uncovered, bestEvaluation); + checkCandidate(dataset, rule, candidate, uncovered, covered, bestEvaluation); } if (bestEvaluation.condition != null) { @@ -180,7 +187,7 @@ public int grow( rule.getPremise().addSubcondition(bestEvaluation.condition); rule.setCoveringInformation(bestEvaluation.covering); - 
rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); + // rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); carryOn = true; Logger.log("Preferred condition " + rule.getPremise().getSubconditions().size() + " added: " @@ -224,16 +231,16 @@ public int grow( // update covering covering.clear(); - rule.covers(dataset, covering, covering.positives, covering.negatives); + positives.clear(); + negatives.clear(); + rule.covers(dataset, covering, positives, negatives); covered.clear(); - covered.addAll(covering.positives); - covered.addAll(covering.negatives); - - rule.getCoveredPositives().setAll(covering.positives); - rule.getCoveredNegatives().setAll(covering.negatives); + covered.addAll(positives); + covered.addAll(negatives); + rule.setCoveredPositives(positives); + rule.setCoveredNegatives(negatives); rule.setCoveringInformation(covering); - rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); Logger.log("Condition " + rule.getPremise().getSubconditions().size() + " added: " + rule.toString() + "\n", Level.FINER); @@ -282,7 +289,7 @@ public int grow( rule.getCoveredNegatives().setAll(covering.negatives); rule.setCoveringInformation(covering); - rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); + // rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); Logger.log("Condition " + rule.getPremise().getSubconditions().size() + " added: " + rule.toString() + "\n", Level.FINER); @@ -293,9 +300,27 @@ public int grow( } while (carryOn); } + covering.clear(); + positives.clear(); + negatives.clear(); + rule.covers(dataset, covering, positives, negatives); + + // ugly + covering.positives = positives; + covering.negatives = negatives; + rule.setCoveringInformation(covering); + + rule.setCoveredNegatives(negatives); + rule.setCoveredPositives(positives); + // if rule has been successfully grown int addedConditionsCount = rule.getPremise().getSubconditions().size() - initialConditionsCount; rule.setInducedContitionsCount(addedConditionsCount); + + if (addedConditionsCount > 0) { + rule.updateWeightAndPValue(dataset, covering, params.getVotingMeasure()); + } + return addedConditionsCount; } @@ -304,10 +329,11 @@ protected boolean checkCandidate( ExampleSet dataset, Rule rule, ConditionBase candidate, - Set uncovered, + Set uncovered, + Set covered, ConditionEvaluation currentBest) { - boolean ok = super.checkCandidate(dataset, rule, candidate, uncovered, currentBest); + boolean ok = super.checkCandidate(dataset, rule, candidate, uncovered, covered, currentBest); // verify knowledge only on elementary conditions if (ok && candidate instanceof ElementaryCondition) { diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertSnC.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertSnC.java index 94d18d55..7e66feee 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertSnC.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionExpertSnC.java @@ -53,25 +53,28 @@ public RuleSetBase run(final ExampleSet dataset) { Logger.log("RegressionExpertSnC.run()\n", Level.FINE); double beginTime; beginTime = System.nanoTime(); - - RuleSetBase ruleset = factory.create(dataset); + + SortedExampleSetEx sortedDataset = (SortedExampleSetEx)finder.preprocess(dataset); + RuleSetBase ruleset = factory.create(sortedDataset); Attribute label = 
dataset.getAttributes().getLabel(); - SortedExampleSet ses = new SortedExampleSetEx(dataset, label, SortedExampleSet.INCREASING); - ses.recalculateAttributeStatistics(ses.getAttributes().getLabel()); + //SortedExampleSet ses = new SortedExampleSetEx(dataset, label, SortedExampleSet.INCREASING); + + sortedDataset.recalculateAttributeStatistics(sortedDataset.getAttributes().getLabel()); if (factory.getType() == RuleFactory.REGRESSION) { - double median = ses.getExample(ses.size() / 2).getLabel(); + double median = sortedDataset.getExample(sortedDataset.size() / 2).getLabel(); RegressionRuleSet tmp = (RegressionRuleSet)ruleset; tmp.setDefaultValue(median); } - Set uncovered = new HashSet(); + //Set uncovered = new HashSet(); + Set uncovered = new IntegerBitSet(dataset.size()); double weighted_PN = 0; // at the beginning rule set does not cover any examples - for (int id = 0; id < ses.size(); ++id) { + for (int id = 0; id < sortedDataset.size(); ++id) { uncovered.add(id); - Example ex = ses.getExample(id); - double w = ses.getAttributes().getWeight() == null ? 1.0 : ex.getWeight(); + Example ex = sortedDataset.getExample(id); + double w = sortedDataset.getAttributes().getWeight() == null ? 1.0 : ex.getWeight(); weighted_PN += w; } @@ -97,17 +100,17 @@ public RuleSetBase run(final ExampleSet dataset) { rule.setCoveredNegatives(new IntegerBitSet(dataset.size())); rule.getCoveredPositives().setAll(); - erf.adjust(rule, dataset, uncovered); + erf.adjust(rule, sortedDataset, uncovered); Logger.log("Expert rule: " + rule.toString() + "\n", Level.FINE); double t = System.nanoTime(); - finder.grow(rule, ses, uncovered); + finder.grow(rule, sortedDataset, uncovered); ruleset.setGrowingTime( ruleset.getGrowingTime() + (System.nanoTime() - t) / 1e9); if (params.isPruningEnabled()) { Logger.log("Before prunning: " + rule.toString() + "\n" , Level.FINE); t = System.nanoTime(); - finder.prune(rule, ses, uncovered); + finder.prune(rule, sortedDataset, uncovered); ruleset.setPruningTime( ruleset.getPruningTime() + (System.nanoTime() - t) / 1e9); } Logger.log("Candidate rule: " + rule.toString() + "\n", Level.FINE); @@ -115,7 +118,7 @@ public RuleSetBase run(final ExampleSet dataset) { Logger.log( "\r" + StringUtils.repeat("\t", 10) + "\r", Level.INFO); Logger.log("\t" + totalExpertRules + " expert rules, " + (++totalAutoRules) + " auto rules" , Level.INFO); - finder.postprocess(rule, ses); + finder.postprocess(rule, sortedDataset); ruleset.addRule(rule); // remove examples covered by the rule and update statistics @@ -123,8 +126,8 @@ public RuleSetBase run(final ExampleSet dataset) { uncovered.removeAll(rule.getCoveredNegatives()); uncovered_pn = 0; for (int id : uncovered) { - Example e = ses.getExample(id); - uncovered_pn += ses.getAttributes().getWeight() == null ? 1.0 : e.getWeight(); + Example e = sortedDataset.getExample(id); + uncovered_pn += sortedDataset.getAttributes().getWeight() == null ? 
1.0 : e.getWeight(); } } @@ -151,14 +154,14 @@ public RuleSetBase run(final ExampleSet dataset) { rule.getCoveredPositives().setAll(); double t = System.nanoTime(); - carryOn = (finder.grow(rule, ses, uncovered) > 0); + carryOn = (finder.grow(rule, sortedDataset, uncovered) > 0); ruleset.setGrowingTime( ruleset.getGrowingTime() + (System.nanoTime() - t) / 1e9); if (carryOn) { if (params.isPruningEnabled()) { Logger.log("Before prunning: " + rule.toString() + "\n" , Level.FINE); t = System.nanoTime(); - finder.prune(rule, ses, uncovered); + finder.prune(rule, sortedDataset, uncovered); ruleset.setPruningTime( ruleset.getPruningTime() + (System.nanoTime() - t) / 1e9); } Logger.log("Candidate rule: " + rule.toString() + "\n", Level.FINE); @@ -171,8 +174,8 @@ public RuleSetBase run(final ExampleSet dataset) { uncovered_pn = 0; for (int id : uncovered) { - Example e = ses.getExample(id); - uncovered_pn += ses.getAttributes().getWeight() == null ? 1.0 : e.getWeight(); + Example e = sortedDataset.getExample(id); + uncovered_pn += sortedDataset.getAttributes().getWeight() == null ? 1.0 : e.getWeight(); } // stop if number of examples remaining is less than threshold @@ -184,7 +187,7 @@ public RuleSetBase run(final ExampleSet dataset) { if (uncovered.size() == previouslyUncovered) { carryOn = false; } else { - finder.postprocess(rule, ses); + finder.postprocess(rule, sortedDataset); ruleset.addRule(rule); Logger.log( "\r" + StringUtils.repeat("\t", 10) + "\r", Level.INFO); Logger.log("\t" + totalExpertRules + " expert rules, " + (++totalAutoRules) + " auto rules" , Level.INFO); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionFinder.java index b73372c3..93f15f20 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionFinder.java @@ -19,6 +19,7 @@ import com.rapidminer.example.Attribute; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; +import com.rapidminer.example.set.SortedExampleSet; import java.security.InvalidParameterException; import java.util.*; @@ -38,6 +39,13 @@ public RegressionFinder(final InductionParameters params) { RegressionRule.setUseMean(params.isMeanBasedRegression()); } + @Override + public ExampleSet preprocess(ExampleSet trainSet) { + Attribute label = trainSet.getAttributes().getLabel(); + SortedExampleSetEx ses = new SortedExampleSetEx(trainSet, label, SortedExampleSet.INCREASING); + return ses; + } + protected ElementaryCondition induceCondition_mean( final Rule rule, final ExampleSet dataset, @@ -158,7 +166,7 @@ class Stats{ double new_n = stats[c].sum_new_w; double new_p = 0; - // iterate over elements from the entire set + // iterate over elements from the positive range for (int j = lo; j < hi; ++j) { if (covs[c].contains(j)) { double wj = set.weights[j]; @@ -223,13 +231,13 @@ class Stats{ // evaluate straight condition ElementaryCondition candidate = new ElementaryCondition( attr.getName(), new SingletonSet((double)i, attr.getMapping().getValues())); - checkCandidate(dataset, rule, candidate, uncovered, best); + checkCandidate(dataset, rule, candidate, uncovered, covered, best); // evaluate complementary condition if enabled if (params.isConditionComplementEnabled()) { candidate = new ElementaryCondition( attr.getName(), new SingletonSetComplement((double) i, 
attr.getMapping().getValues())); - checkCandidate(dataset, rule, candidate, uncovered, best); + checkCandidate(dataset, rule, candidate, uncovered, covered, best); } } } @@ -319,11 +327,11 @@ protected ElementaryCondition induceCondition( // evaluate left-side condition a < v ElementaryCondition candidate = new ElementaryCondition(attr.getName(), Interval.create_le(midpoint)); - checkCandidate(dataset, rule, candidate, uncovered, best); + checkCandidate(dataset, rule, candidate, uncovered, covered, best); // evaluate right-side condition v <= a candidate = new ElementaryCondition(attr.getName(), Interval.create_geq(midpoint)); - checkCandidate(dataset, rule, candidate, uncovered, best); + checkCandidate(dataset, rule, candidate, uncovered, covered, best); } } else { // try all possible conditions @@ -331,13 +339,13 @@ protected ElementaryCondition induceCondition( // evaluate straight condition ElementaryCondition candidate = new ElementaryCondition( attr.getName(), new SingletonSet((double)i, attr.getMapping().getValues())); - checkCandidate(dataset, rule, candidate, uncovered, best); + checkCandidate(dataset, rule, candidate, uncovered, covered, best); // evaluate complementary condition if enabled if (params.isConditionComplementEnabled()) { candidate = new ElementaryCondition( attr.getName(), new SingletonSetComplement((double) i, attr.getMapping().getValues())); - checkCandidate(dataset, rule, candidate, uncovered, best); + checkCandidate(dataset, rule, candidate, uncovered, covered, best); } } } @@ -379,7 +387,8 @@ protected boolean checkCandidate( ExampleSet dataset, Rule rule, ConditionBase candidate, - Set uncovered, + Set uncovered, + Set covered, ConditionEvaluation currentBest) { try { diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionSnC.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionSnC.java index 36131886..17bb1bb7 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionSnC.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/RegressionSnC.java @@ -47,25 +47,27 @@ public RuleSetBase run(final ExampleSet dataset) { Logger.log("RegressionSnC.run()\n", Level.FINE); - RuleSetBase ruleset = factory.create(dataset); + SortedExampleSetEx sortedDataset = (SortedExampleSetEx)finder.preprocess(dataset); + RuleSetBase ruleset = factory.create(sortedDataset); Attribute label = dataset.getAttributes().getLabel(); - SortedExampleSetEx ses = new SortedExampleSetEx(dataset, label, SortedExampleSet.INCREASING); - ses.recalculateAttributeStatistics(ses.getAttributes().getLabel()); + //SortedExampleSetEx ses = new SortedExampleSetEx(dataset, label, SortedExampleSet.INCREASING); + + sortedDataset.recalculateAttributeStatistics(sortedDataset.getAttributes().getLabel()); if (factory.getType() == RuleFactory.REGRESSION) { - double median = ses.getExample(ses.size() / 2).getLabel(); + double median = sortedDataset.getExample(sortedDataset.size() / 2).getLabel(); RegressionRuleSet tmp = (RegressionRuleSet)ruleset; tmp.setDefaultValue(median); // use this even in mean-based variant } - Set uncovered = new HashSet(); - //Set uncovered = new IntegerBitSet(ses.size()); + //Set uncovered = new HashSet(); + Set uncovered = new IntegerBitSet(sortedDataset.size()); double weighted_PN = 0; // at the beginning rule set does not cover any examples - for (int id = 0; id < ses.size(); ++id) { + for (int id = 0; id < sortedDataset.size(); ++id) { 
uncovered.add(id); - Example ex = ses.getExample(id); - double w = ses.getAttributes().getWeight() == null ? 1.0 : ex.getWeight(); + Example ex = sortedDataset.getExample(id); + double w = sortedDataset.getAttributes().getWeight() == null ? 1.0 : ex.getWeight(); weighted_PN += w; } @@ -92,21 +94,21 @@ public RuleSetBase run(final ExampleSet dataset) { rule.setRuleOrderNum(ruleset.getRules().size()); double t = System.nanoTime(); - carryOn = (finder.grow(rule, ses, uncovered) > 0); + carryOn = (finder.grow(rule, sortedDataset, uncovered) > 0); ruleset.setGrowingTime( ruleset.getGrowingTime() + (System.nanoTime() - t) / 1e9); if (carryOn) { if (params.isPruningEnabled()) { Logger.log("Before prunning: " + rule.toString() + "\n" , Level.FINE); t = System.nanoTime(); - finder.prune(rule, ses, uncovered); + finder.prune(rule, sortedDataset, uncovered); ruleset.setPruningTime( ruleset.getPruningTime() + (System.nanoTime() - t) / 1e9); } Logger.log("Candidate rule: " + rule.toString() + "\n", Level.FINE); Logger.log(".", Level.INFO); Covering covering = new Covering(); - rule.covers(ses, covering, covering.positives, covering.negatives); + rule.covers(sortedDataset, covering, covering.positives, covering.negatives); // remove covered examples int previouslyUncovered = uncovered.size(); @@ -115,7 +117,7 @@ public RuleSetBase run(final ExampleSet dataset) { uncovered_pn = 0; for (int id : uncovered) { - Example e = dataset.getExample(id); + Example e = sortedDataset.getExample(id); uncovered_pn += dataset.getAttributes().getWeight() == null ? 1.0 : e.getWeight(); } @@ -128,7 +130,7 @@ public RuleSetBase run(final ExampleSet dataset) { if (uncovered.size() == previouslyUncovered) { carryOn = false; } else { - finder.postprocess(rule, ses); + finder.postprocess(rule, sortedDataset); ruleset.addRule(rule); Logger.log( "\r" + StringUtils.repeat("\t", 10) + "\r", Level.INFO); Logger.log("\t" + (++totalRules) + " rules" , Level.INFO); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankExpertFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankExpertFinder.java index 6f51942d..1973d834 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankExpertFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankExpertFinder.java @@ -16,15 +16,16 @@ import adaa.analytics.rules.logic.quality.IQualityMeasure; import adaa.analytics.rules.logic.quality.LogRank; -import adaa.analytics.rules.logic.representation.KaplanMeierEstimator; +import adaa.analytics.rules.logic.representation.*; -import adaa.analytics.rules.logic.representation.Rule; -import adaa.analytics.rules.logic.representation.SurvivalRule; +import com.rapidminer.example.Attribute; import com.rapidminer.example.ExampleSet; +import com.rapidminer.example.set.SortedExampleSet; import com.rapidminer.tools.container.Pair; import java.util.HashSet; import java.util.Set; +import java.util.logging.Level; /** * Class for growing and pruning log rank-based survival rules with user's knowledge. 
@@ -36,6 +37,15 @@ public class SurvivalLogRankExpertFinder extends RegressionExpertFinder { public SurvivalLogRankExpertFinder(InductionParameters params) { super(params); + this.params.setMeanBasedRegression(false); + } + + SurvivalLogRankFinder.Implementation implementation = new SurvivalLogRankFinder.Implementation(); + + + @Override + public ExampleSet preprocess(ExampleSet trainSet) { + return implementation.preprocess(trainSet); } /** @@ -50,9 +60,19 @@ public void postprocess( final Rule rule, final ExampleSet dataset) { - Covering cov = rule.covers(dataset); - Set covered = cov.positives; - KaplanMeierEstimator kme = new KaplanMeierEstimator(dataset, covered); - ((SurvivalRule)rule).setEstimator(kme); + super.postprocess(rule, dataset); + implementation.postprocess(rule, dataset); + } + + @Override + protected boolean checkCandidate( + ExampleSet dataset, + Rule rule, + ConditionBase candidate, + Set uncovered, + Set covered, + ConditionEvaluation currentBest) { + + return implementation.checkCandidate(dataset, rule, candidate, uncovered, covered, currentBest, this); } } diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankFinder.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankFinder.java index c51b3a19..d4bd9553 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankFinder.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/induction/SurvivalLogRankFinder.java @@ -16,16 +16,17 @@ import adaa.analytics.rules.logic.quality.IQualityMeasure; import adaa.analytics.rules.logic.quality.LogRank; -import adaa.analytics.rules.logic.representation.KaplanMeierEstimator; +import adaa.analytics.rules.logic.representation.*; -import adaa.analytics.rules.logic.representation.Rule; -import adaa.analytics.rules.logic.representation.SurvivalRule; +import com.rapidminer.example.Attribute; import com.rapidminer.example.ExampleSet; +import com.rapidminer.example.set.SortedExampleSet; import com.rapidminer.tools.container.Pair; import org.jetbrains.annotations.NotNull; import java.util.HashSet; import java.util.Set; +import java.util.logging.Level; /** * Class for growing and pruning log rank-based survival rules. 
@@ -35,11 +36,114 @@ */ public class SurvivalLogRankFinder extends RegressionFinder{ + public static class Implementation { + public ExampleSet preprocess(ExampleSet trainSet) { + Attribute survTime = trainSet.getAttributes().getSpecial(SurvivalRule.SURVIVAL_TIME_ROLE); + SortedExampleSetEx ses = new SortedExampleSetEx(trainSet, survTime, SortedExampleSet.INCREASING); + return ses; + } + + public void postprocess( + final Rule rule, + final ExampleSet dataset) { + + KaplanMeierEstimator kme = new KaplanMeierEstimator(dataset, rule.getCoveredPositives()); + ((SurvivalRule)rule).setEstimator(kme); + } + + protected boolean checkCandidate( + ExampleSet dataset, + Rule rule, + ConditionBase candidate, + Set uncovered, + Set covered, + ConditionEvaluation currentBest, + RegressionFinder finder) { + + try { + + IntegerBitSet conditionCovered = new IntegerBitSet(dataset.size()); + candidate.evaluate(dataset, conditionCovered); + + IntegerBitSet ruleCovered = conditionCovered.clone(); + ruleCovered.retainAll(covered); + + double p = 0; + double new_p = 0; + + if (dataset.getAttributes().getWeight() == null) { + // unweighted examples + p = conditionCovered.calculateIntersectionSize((IntegerBitSet) covered); + new_p = conditionCovered.calculateIntersectionSize((IntegerBitSet) uncovered, (IntegerBitSet) covered); + + } else { + // calculate weights of newly covered examples + for (int id : conditionCovered) { + if (covered.contains(id)) { + double w = dataset.getExample(id).getWeight(); + p += w; + if (uncovered.contains(id)) { + new_p += w; + } + } + } + } + + if (finder.checkCoverage(p, 0, new_p, 0, dataset.size(), 0, uncovered.size(), rule.getRuleOrderNum())) { + Covering cov = new Covering(); + cov.positives = ruleCovered; + cov.weighted_p = p; + + double quality = finder.params.getInductionMeasure().calculate(dataset, cov); + + if (candidate instanceof ElementaryCondition) { + ElementaryCondition ec = (ElementaryCondition) candidate; + quality = finder.modifier.modifyQuality(quality, ec.getAttribute(), cov.weighted_p, new_p); + } + + if (quality > currentBest.quality || + (quality == currentBest.quality && (new_p > currentBest.covered || currentBest.opposite))) { + + Logger.log("\t\tCurrent best: " + candidate + " (p=" + cov.weighted_p + + ", new_p=" + (double) new_p + + ", P=" + cov.weighted_P + + ", mean_y=" + cov.mean_y + ", mean_y2=" + cov.mean_y2 + ", stddev_y=" + cov.stddev_y + + ", quality=" + quality + "\n", Level.FINEST); + + candidate.setCovering(conditionCovered); + + currentBest.quality = quality; + currentBest.condition = candidate; + currentBest.covered = new_p; + currentBest.covering = cov; + currentBest.opposite = (candidate instanceof ElementaryCondition) && + (((ElementaryCondition) candidate).getValueSet() instanceof SingletonSetComplement); + + //rule.setWeight(quality); + return true; + } + } + + } catch (Exception e) { + e.printStackTrace(); + } + return false; + } + } + + private Implementation implementation = new Implementation(); + public SurvivalLogRankFinder(InductionParameters params) { super(params); + this.params.setMeanBasedRegression(false); // TODO Auto-generated constructor stub } + @Override + public ExampleSet preprocess(ExampleSet trainSet) { + return implementation.preprocess(trainSet); + } + /** * Postprocesses a rule. 
* @@ -53,10 +157,18 @@ public void postprocess( final ExampleSet dataset) { super.postprocess(rule, dataset); + implementation.postprocess(rule, dataset); + } + + @Override + protected boolean checkCandidate( + ExampleSet dataset, + Rule rule, + ConditionBase candidate, + Set uncovered, + Set covered, + ConditionEvaluation currentBest) { - Covering cov = rule.covers(dataset); - Set covered = cov.positives; - KaplanMeierEstimator kme = new KaplanMeierEstimator(dataset, covered); - ((SurvivalRule)rule).setEstimator(kme); + return implementation.checkCandidate(dataset, rule, candidate, uncovered, covered, currentBest, this); } } diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/quality/LogRank.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/quality/LogRank.java index 08ef973b..52522eef 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/quality/LogRank.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/quality/LogRank.java @@ -16,8 +16,10 @@ import adaa.analytics.rules.logic.induction.ContingencyTable; import adaa.analytics.rules.logic.induction.Covering; +import adaa.analytics.rules.logic.representation.IntegerBitSet; import adaa.analytics.rules.logic.representation.KaplanMeierEstimator; +import adaa.analytics.rules.logic.representation.Logger; import com.rapidminer.example.ExampleSet; import org.apache.commons.math3.distribution.ChiSquaredDistribution; @@ -26,6 +28,8 @@ import java.io.Serializable; import java.util.HashSet; import java.util.Set; +import java.util.TreeSet; +import java.util.logging.Level; /** * Class representing log-rank test. @@ -53,14 +57,14 @@ public double calculate(double p, double n, double P, double N) { public double calculate(ExampleSet dataset, ContingencyTable ct) { Covering cov = (Covering)ct; - Set coveredIndices = cov.positives; // in survival rules all examples are classified as positives - Set uncoveredIndices = new HashSet(); - for (int i = 0; i < dataset.size(); ++i) { - if (!coveredIndices.contains(i)) { - uncoveredIndices.add(i); - } + IntegerBitSet coveredIndices = (IntegerBitSet) cov.positives; // in survival rules all examples are classified as positives + if (coveredIndices == null) { + assert false: "LogRank.calculate() requires IntegerBitSet as an argument"; } + IntegerBitSet uncoveredIndices = new IntegerBitSet(dataset.size()); + coveredIndices.negate(uncoveredIndices); + KaplanMeierEstimator coveredEstimator = new KaplanMeierEstimator(dataset, coveredIndices); KaplanMeierEstimator uncoveredEstimator = new KaplanMeierEstimator(dataset, uncoveredIndices); @@ -72,34 +76,72 @@ public double calculate(ExampleSet dataset, ContingencyTable ct) { public Pair compareEstimators(KaplanMeierEstimator kme1, KaplanMeierEstimator kme2) { Pair res = new Pair(0.0, 0.0); - - Set eventTimes = new HashSet(); - eventTimes.addAll(kme1.getTimes()); - eventTimes.addAll(kme2.getTimes()); - + // fixme: - if (kme1.getTimes().size() == 0 || kme2.getTimes().size() == 0) { + if (kme1.filled == 0 || kme2.filled == 0) { return res; } - - double x = 0; - double y = 0; - for (double time : eventTimes) { - double m1 = kme1.getEventsCountAt(time); - double n1 = kme1.getRiskSetCountAt(time); - double m2 = kme2.getEventsCountAt(time); - double n2 = kme2.getRiskSetCountAt(time); + double x = 0; + double y = 0; + + int id1 = 0; + int id2 = 0; + + for (; id1 < kme1.filled; ++id1) { + double m1, n1, m2, n2; - //Debug.WriteLine(string.Format("time={0}, m1={1} m2={2} n1={3} n2={4}", time, m1, m2, n1, 
n2)); + for (; (id2 < kme2.filled) && (kme2.times[id2] < kme1.times[id1]); id2++) { + // point only in kme2 + m1 = 0; + n1 = kme1.atRiskCounts[id1]; - double e2 = (n2 / (n1 + n2)) * (m1 + m2); + m2 = kme2.eventsCounts[id2]; + n2 = kme2.atRiskCounts[id2]; + + double e2 = (n2 / (n1 + n2)) * (m1 + m2); + x += m2 - e2; + double n = n1 + n2; + y += (n1 * n2 * (m1 + m2) * (n - m1 - m2)) / (n * n * (n - 1)); + } - x += m2 - e2; + m1 = kme1.eventsCounts[id1]; + n1 = kme1.atRiskCounts[id1]; - double n = n1 + n2; - y += (n1 * n2 * (m1 + m2) * (n - m1 - m2)) / (n * n * (n - 1)); - } + if (id2 < kme2.filled && kme1.times[id1] == kme2.times[id2]) { + // point in both + m2 = kme2.eventsCounts[id2]; + n2 = kme2.atRiskCounts[id2]; + ++id2; + + } else { + // point only in kme1 + m2 = 0; + n2 = kme2.atRiskCounts[Math.min(id2, kme2.filled - 1)]; + } + + double e2 = (n2 / (n1 + n2)) * (m1 + m2); + x += m2 - e2; + double n = n1 + n2; + y += (n1 * n2 * (m1 + m2) * (n - m1 - m2)) / (n * n * (n - 1)); + } + + // remaining kme2 points + for (; id2 < kme2.filled; id2++) { + double m1, n1, m2, n2; + + // point only in kme2 + m1 = 0; + n1 = kme1.atRiskCounts[kme1.filled - 1]; + + m2 = kme2.eventsCounts[id2]; + n2 = kme2.atRiskCounts[id2]; + + double e2 = (n2 / (n1 + n2)) * (m1 + m2); + x += m2 - e2; + double n = n1 + n2; + y += (n1 * n2 * (m1 + m2) * (n - m1 - m2)) / (n * n * (n - 1)); + } res.setFirst((x * x) / y); res.setSecond(1.0 - dist.cumulativeProbability(res.getFirst())); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ConditionBase.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ConditionBase.java index 8ddfd3b0..c59a7725 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ConditionBase.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/ConditionBase.java @@ -14,6 +14,7 @@ ******************************************************************************/ package adaa.analytics.rules.logic.representation; +import com.fasterxml.jackson.annotation.JsonIgnore; import com.rapidminer.example.Example; import com.rapidminer.example.ExampleSet; @@ -50,6 +51,9 @@ public enum Type {FORCED, PREFERRED, NORMAL}; /** Condition type. */ protected Type type = Type.NORMAL; + /** Optional integer bit set for storing condition coverage. */ + protected IntegerBitSet covering = null; + /** Gets {@link #disabled} */ public boolean isDisabled() { return disabled; } /** Sets {@link #disabled} */ @@ -60,6 +64,11 @@ public enum Type {FORCED, PREFERRED, NORMAL}; /** Sets {@link #type} */ public void setType(Type t) { type = t; } + /** Gets {@link #covering}. */ + public IntegerBitSet getCovering() { return covering; } + /** Sets {@link #covering}. */ + public void setCovering(IntegerBitSet c) { covering = c; } + /** * Check whether the condition is prunable (non-FORCED and non-PREFERRED). * @return Value indicating whether condition is prunable. 
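The rewritten LogRank.compareEstimators() replaces the hash set of all event times with a single merge pass over the two sorted time arrays; the per-time-point arithmetic is unchanged. The standalone sketch below restates that update in isolation (class and method names are illustrative, not part of the library):

    // Per-time-point log-rank update as accumulated in LogRank.compareEstimators():
    // m1, m2 - events in the two groups at this time point; n1, n2 - examples still at risk.
    final class LogRankAccumulator {
        double x = 0;   // sum of (observed - expected) events in the second group
        double y = 0;   // sum of hypergeometric variance terms

        void addTimePoint(double m1, double n1, double m2, double n2) {
            double e2 = (n2 / (n1 + n2)) * (m1 + m2);   // expected events in the second group
            x += m2 - e2;
            double n = n1 + n2;
            y += (n1 * n2 * (m1 + m2) * (n - m1 - m2)) / (n * n * (n - 1));
        }

        // The statistic x*x/y is referred to a chi-square distribution with one degree of
        // freedom; compareEstimators() reports 1 - cdf(statistic) as the p-value.
        double statistic() { return (x * x) / y; }
    }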
diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/IntegerBitSet.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/IntegerBitSet.java
index bd6dbce7..e7a7158a 100644
--- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/IntegerBitSet.java
+++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/IntegerBitSet.java
@@ -16,16 +16,23 @@
 
 import com.sun.jna.platform.win32.WinDef;
 
+import java.io.IOException;
+import java.io.ObjectInputStream;
+import java.io.ObjectOutputStream;
+import java.io.Serializable;
 import java.util.Collection;
 import java.util.Iterator;
 import java.util.Set;
 
 /**
  * Upper-bounded set of integers represented internally as a bit vector.
+ *
+ * NOTE: Although the class implements the Serializable interface, its custom serialization/deserialization methods do nothing!
+ *
  * @author Adam Gudys
  *
  */
-public class IntegerBitSet implements Set<Integer> {
+public class IntegerBitSet implements Set<Integer>, Serializable {
 
     /**
      * Iterator for {@link #adaa.analytics.rules.logic.representation.IntegerBitSet}.
@@ -102,9 +109,9 @@ private boolean increment() {
         }
     }
 
-    static final int OFFSET_MASK = 63;
+    public static final int OFFSET_MASK = 63;
 
-    static final int ID_SHIFT = 6;
+    public static final int ID_SHIFT = 6;
 
     /** Array of words for storing bits. */
     private long[] words;
@@ -114,6 +121,8 @@ private boolean increment() {
 
     /** Gets {@link #maxElement}. */
    public int getMaxElement() { return maxElement; }
+
+    public long[] getRawTable() { return words; }
 
    /**
     * Allocates words array for storing bits.
@@ -125,6 +134,20 @@ public IntegerBitSet(int maxElement) {
        words = new long[wordsCount];
    }
 
+
+    /**
+     * Allocates words array for storing bits; if fill is true, all bits are initially set.
+     * @param maxElement Max element that can be stored in the set.
+     */
+    public IntegerBitSet(int maxElement, boolean fill) {
+        this.maxElement = maxElement;
+        int wordsCount = (maxElement + Long.SIZE - 1) / Long.SIZE;
+        words = new long[wordsCount];
+
+        if (fill) {
+            this.setAll();
+        }
+    }
 
    /**
     * Adds new integer to the set (sets an appropriate bit).
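The new two-argument IntegerBitSet constructor creates an already filled set, which is how KaplanMeierEstimator(ExampleSet) further below obtains an index set covering every example without an explicit loop. A small sketch of the two constructors together with the complement pattern used in LogRank.calculate() (variable names are illustrative):

    import adaa.analytics.rules.logic.representation.IntegerBitSet;

    class BitSetSketch {
        // Fully set versus empty set of example indices 0..datasetSize-1.
        IntegerBitSet allExamples(int datasetSize) {
            return new IntegerBitSet(datasetSize, true);    // every bit set up front
        }

        // Complement pattern used when splitting examples into covered and uncovered groups.
        IntegerBitSet complement(IntegerBitSet covered, int datasetSize) {
            IntegerBitSet uncovered = new IntegerBitSet(datasetSize);   // empty, as before
            covered.negate(uncovered);                                  // fills uncovered with the missing indices
            return uncovered;
        }
    }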
@@ -468,4 +491,12 @@ public boolean filteredCompare(IntegerBitSet arg0, IntegerBitSet arg1) { return true; } + + /** Empty deserialization method */ + private void readObject(ObjectInputStream ois) throws IOException, ClassNotFoundException { + } + + /** Empty serialization method */ + private void writeObject(ObjectOutputStream os) throws IOException { + } } diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/KaplanMeierEstimator.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/KaplanMeierEstimator.java index 51a366c7..71cf1c5b 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/KaplanMeierEstimator.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/KaplanMeierEstimator.java @@ -20,6 +20,7 @@ import com.rapidminer.example.ExampleSet; import java.io.Serializable; +import java.security.InvalidParameterException; import java.util.*; /** @@ -33,89 +34,59 @@ public class KaplanMeierEstimator implements Serializable { private static final long serialVersionUID = -6949465091584014494L; /** Array of survival estimator points */ - protected ArrayList survInfo = new ArrayList(); - - /** + //protected ArrayList survInfo = new ArrayList(); + + public static final int NotAssigned = Integer.MIN_VALUE; + + public double[] times; + public int[] eventsCounts; + public int[] censoredCounts; + public int[] atRiskCounts; + public double[] probabilities; + + public int filled = 0; + + + /** * Adds a new time point to the estimator. * @param time Survival time. * @param probability Survival probability. */ + /* public void addSurvInfo(double time, double probability) { survInfo.add(new SurvInfo(time, probability)); } - + */ + /** * Creates empty instance. */ public KaplanMeierEstimator() {} - + /** * Generates survival estimator function from survival data. * @param data Example set with attribute of {@link adaa.analytics.rules.logic.representation.SurvivalRule#SURVIVAL_TIME_ROLE}. */ public KaplanMeierEstimator(ExampleSet data) { - Attribute survTime = data.getAttributes().getSpecial(SurvivalRule.SURVIVAL_TIME_ROLE); - Attribute survStat = data.getAttributes().getLabel(); - - this.survInfo.ensureCapacity(data.size()); - SurvInfo info[] = new SurvInfo[data.size()]; - - int j = 0; - for (Example e : data) { - double t = e.getValue(survTime); - boolean isCensored = (e.getValue(survStat) == 0); - - int eventsCount = isCensored ? 0 : 1; - info[j] = new SurvInfo(t, eventsCount, 1 - eventsCount); - ++j; - } - - Arrays.sort(info, new SurvInfoComparer(SurvInfoComparer.By.TimeAsc)); - assert info[0].getTime() <= info[info.length - 1].getTime(); - int atRiskCount = data.size(); - - int idx = 0; - while (idx < info.length) { - double t = info[idx].getTime(); - int startIdx = idx; - - while (idx < info.length && info[idx].getTime() == t) { - idx++; - } - - int eventsAtTimeCount = 0; - int censoredAtTimeCount = 0; - for (int i = startIdx; i < idx; i++) { - if (info[i].getEventsCount() == 1) { - eventsAtTimeCount++; - } else { - assert (info[i].getCensoredCount() == 1); - censoredAtTimeCount++; - } - } - - SurvInfo si = new SurvInfo(t, eventsAtTimeCount, censoredAtTimeCount); - si.setAtRiskCount(atRiskCount); - - this.survInfo.add(si); - - atRiskCount -= eventsAtTimeCount; - atRiskCount -= censoredAtTimeCount; - } - - this.calculateProbability(); + this(data, new IntegerBitSet(data.size(), true)); } - + /** * Converts estimator to the text. * @return Estimator in the text form. 
*/ public String save() { StringBuilder sb = new StringBuilder(); + /* sb.append(survInfo.size() + ":"); for (SurvInfo si: survInfo) { sb.append(si.time + " " + si.probability + " " ); } + */ + sb.append(filled + ":"); + for (int i = 0; i < filled; ++i) { + sb.append(times[i] + " " + probabilities[i] + " " ); + } return sb.toString(); } @@ -125,14 +96,21 @@ public String save() { */ public void load(String s) { int idx = s.indexOf(':'); - int count = Integer.parseInt(s.substring(0, idx)); + filled = Integer.parseInt(s.substring(0, idx)); s = s.substring(idx + 1); - + + String[] numbers = s.split(" "); + int num_idx = 0; + + this.reserve(filled); + + Arrays.fill(atRiskCounts, NotAssigned); + Arrays.fill(censoredCounts, NotAssigned); + Arrays.fill(eventsCounts, NotAssigned); + + /* survInfo = new ArrayList(count); - - String[] numbers = s.split(" "); - int num_idx = 0; - + for (int i = 0; i < count; ++i) { survInfo.add(new SurvInfo( @@ -140,6 +118,12 @@ public void load(String s) { Double.parseDouble(numbers[num_idx++]) )); } + */ + + for (int i = 0; i < filled; ++i) { + times[i] = Double.parseDouble(numbers[num_idx++]); + probabilities[i] = Double.parseDouble(numbers[num_idx++]); + } } /** @@ -148,58 +132,73 @@ public void load(String s) { * @param indices Indices of the examples to be taken into account when building the estimator. */ public KaplanMeierEstimator(ExampleSet data, Set indices) { - Attribute survTime = data.getAttributes().getSpecial(SurvivalRule.SURVIVAL_TIME_ROLE); - Attribute survStat = data.getAttributes().getLabel(); - - this.survInfo.ensureCapacity(indices.size()); - SurvInfo[] info = new SurvInfo[indices.size()]; - - int j = 0; - for (int id : indices) { - Example e = data.getExample(id); - double t = e.getValue(survTime); - boolean isCensored = (e.getValue(survStat) == 0); - - int eventsCount = isCensored ? 0 : 1; - info[j] = new SurvInfo(t, eventsCount, 1 - eventsCount); - ++j; + + SortedExampleSetEx set = (data instanceof SortedExampleSetEx) ? 
				(SortedExampleSetEx)data : null;
+		if (set == null) {
+			throw new InvalidParameterException("KaplanMeierEstimator supports only SortedExampleSetEx example sets");
 		}
-		
-		Arrays.sort(info, new SurvInfoComparer(SurvInfoComparer.By.TimeAsc));
-		assert info[0].getTime() <= info[info.length - 1].getTime();
+
+		this.reserve(indices.size());
+
+		Arrays.fill(atRiskCounts, NotAssigned);
+		Arrays.fill(censoredCounts, NotAssigned);
+		Arrays.fill(eventsCounts, NotAssigned);
+		Arrays.fill(probabilities, NotAssigned);
+
 		int atRiskCount = indices.size();
 
-		int idx = 0;
-		while (idx < info.length) {
-			double t = info[idx].getTime();
-			int startIdx = idx;
-			
-			while (idx < info.length && info[idx].getTime() == t) {
-				idx++;
+		Iterator<Integer> it = indices.iterator();
+		int i = 0;
+
+		int eventsAtTimeCount = 0;
+		int censoredAtTimeCount = 0;
+		double prev_t = -1;
+
+		while (it.hasNext()) {
+			int id = it.next();
+			double t = set.survivalTimes[id];
+
+			// time point has changed - add surv info
+			if (t != prev_t && prev_t > 0) {
+				times[filled] = prev_t;
+				eventsCounts[filled] = eventsAtTimeCount;
+				censoredCounts[filled] = censoredAtTimeCount;
+				atRiskCounts[filled] = atRiskCount;
+
+				++filled;
+
+				atRiskCount -= eventsAtTimeCount + censoredAtTimeCount;
+
+				eventsAtTimeCount = 0;
+				censoredAtTimeCount = 0;
 			}
 
-			int eventsAtTimeCount = 0;
-			int censoredAtTimeCount = 0;
-			for (int i = startIdx; i < idx; i++) {
-				if (info[i].getEventsCount() == 1) {
-					eventsAtTimeCount++;
-				} else {
-					assert (info[i].getCensoredCount() == 1);
-					censoredAtTimeCount++;
-				}
+			if (set.labels[id] == 1) {
+				++eventsAtTimeCount;
+			} else {
+				++censoredAtTimeCount;
 			}
 
-			SurvInfo si = new SurvInfo(t, eventsAtTimeCount, censoredAtTimeCount);
-			si.setAtRiskCount(atRiskCount);
+			prev_t = t;
+		}
 
-			this.survInfo.add(si);
+		times[filled] = prev_t;
+		eventsCounts[filled] = eventsAtTimeCount;
+		censoredCounts[filled] = censoredAtTimeCount;
+		atRiskCounts[filled] = atRiskCount;
 
-			atRiskCount -= eventsAtTimeCount;
-			atRiskCount -= censoredAtTimeCount;
-		}
+		++filled;
 
 		this.calculateProbability();
 	}
+
+	protected void reserve(int size) {
+		times = new double[size];
+		atRiskCounts = new int[size];
+		censoredCounts = new int[size];
+		eventsCounts = new int[size];
+		probabilities = new double[size];
+	}
 
 	/**
 	 * Average several estimators.
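The estimator now keeps its points in parallel arrays (times, eventsCounts, censoredCounts, atRiskCounts, probabilities) filled in one pass over the survival-time-sorted examples; only the first filled entries are valid. The product-limit recurrence applied later by calculateProbability() can be restated on its own as follows (standalone sketch, not part of the class):

    // Kaplan-Meier product-limit recurrence used by calculateProbability():
    // S(t_i) = S(t_{i-1}) * (n_i - d_i) / n_i, with d_i events and n_i examples at risk at t_i.
    final class KaplanMeierSketch {
        static double[] survivalProbabilities(int[] eventsCounts, int[] atRiskCounts, int filled) {
            double[] probabilities = new double[filled];
            for (int i = 0; i < filled; ++i) {
                probabilities[i] = (atRiskCounts[i] - eventsCounts[i]) / (double) atRiskCounts[i];
                if (i > 0) {
                    probabilities[i] *= probabilities[i - 1];   // product over all earlier time points
                }
            }
            return probabilities;
        }
    }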
@@ -212,26 +211,27 @@ public static KaplanMeierEstimator average(KaplanMeierEstimator[] estimators) { for (KaplanMeierEstimator e : estimators) { uniqueTime.addAll(e.getTimes()); } - - double[] time = new double[uniqueTime.size()]; - double[] probabilities = new double[uniqueTime.size()]; - + + // get averaged estimator + KaplanMeierEstimator avgKm = new KaplanMeierEstimator(); + avgKm.reserve(uniqueTime.size()); + avgKm.filled = uniqueTime.size(); + + Arrays.fill(avgKm.atRiskCounts, NotAssigned); + Arrays.fill(avgKm.censoredCounts, NotAssigned); + Arrays.fill(avgKm.eventsCounts, NotAssigned); + // average probabilities for all time points - Iterator t = uniqueTime.iterator(); - for (int i = 0; i < time.length; ++i) { - time[i] = t.next(); - double p = 0; + Iterator it= uniqueTime.iterator(); + for (int i = 0; i < avgKm.times.length; ++i) { + double t = it.next(); + double p = 0; + for (KaplanMeierEstimator e: estimators) { - p+= e.getProbabilityAt(time[i]); - + p += e.getProbabilityAt(t); } - probabilities[i] = p / estimators.length; - } - - // get averaged estimator - KaplanMeierEstimator avgKm = new KaplanMeierEstimator(); - for (int i = 0; i < time.length; i++) { - avgKm.addSurvInfo(time[i], probabilities[i]); + avgKm.times[i] = t; + avgKm.probabilities[i] = p / estimators.length; } return avgKm; @@ -242,11 +242,17 @@ public static KaplanMeierEstimator average(KaplanMeierEstimator[] estimators) { * @return Array of time points. */ public ArrayList getTimes() { - ArrayList times = new ArrayList(survInfo.size()); + /* + ArrayList times = new ArrayList(survInfo.size()); for (SurvInfo si : survInfo) { times.add(si.getTime()); } - return times; + */ + ArrayList out = new ArrayList(); + for (int i = 0; i < filled; ++i) { + out.add(times[i]); + } + return out; } /** @@ -255,75 +261,41 @@ public ArrayList getTimes() { * @return Survival probability. */ public double getProbabilityAt(double time) { - int idx = Collections.binarySearch(survInfo, new SurvInfo(time), new SurvInfoComparer(SurvInfoComparer.By.TimeAsc)); - + + int idx = Arrays.binarySearch(times, 0, filled, time); + if (idx >= 0) { - return this.survInfo.get(idx).getProbability(); + return probabilities[idx]; } - // bitwise complement of the index of the next element that is larger than item - // or, if there is no larger element, the bitwise complement of Count idx = ~idx; - int n = this.survInfo.size(); - if (idx == n) { - return this.survInfo.get(n - 1).getProbability(); + if (idx == filled) { + return probabilities[filled - 1]; } if (idx == 0) { return 1.0; } - double p = this.survInfo.get(idx - 1).getProbability(); + double p = probabilities[idx - 1]; assert (p != Double.NaN); return p; } - public double getTimeForProbability(double probability) { - assert probability >= 0.0 && probability <= 1.0; - //SurvInfo tmpSurvInfo = new SurvInfo(1.0, probability); - int idx = Collections.binarySearch(survInfo, new SurvInfo(1.0, probability), - new SurvInfoComparer(SurvInfoComparer.By.ProbabilityDesc)); - Optional lastThisProbability = survInfo. - stream(). - filter(i -> i.probability == probability). 
- reduce((first, second) -> second); - - if (idx >= 0) { - return lastThisProbability.get().getTime(); - } - - // bitwise complement of the index of the next element that is larger than item - // or, if there is no larger element, the bitwise complement of Count - idx = ~idx; - int n = this.survInfo.size(); - if (idx == n) { - return this.survInfo.get(n - 1).getTime(); - } - - if (idx == 0) { - return this.survInfo.get(idx).getTime(); - } - - double t = this.survInfo.get(idx).getTime(); - assert (t != Double.NaN); - return t; - } - /** * Gets number of events at given time point. * @param time Time. * @return Number of events. */ public int getEventsCountAt(double time) { - int idx = Collections.binarySearch(survInfo, new SurvInfo(time), new SurvInfoComparer(SurvInfoComparer.By.TimeAsc)); - - if (idx >= 0) { - return this.survInfo.get(idx).getEventsCount(); - } + int idx = Arrays.binarySearch(times, 0, filled, time); + if (idx >= 0) { + return eventsCounts[idx]; + } - return 0; + return 0; } /** @@ -332,22 +304,16 @@ public int getEventsCountAt(double time) { * @return Risk. */ public int getRiskSetCountAt(double time) { - int idx = Collections.binarySearch(survInfo, new SurvInfo(time), new SurvInfoComparer(SurvInfoComparer.By.TimeAsc)); - - if (idx >= 0) { - return this.survInfo.get(idx).getAtRiskCount(); - } - - // bitwise complement of the index of the next element that is larger than item - // or, if there is no larger element, the bitwise complement of Count - idx = ~idx; + int idx = Arrays.binarySearch(times, 0, filled, time); - int n = this.survInfo.size(); - if (idx == n) { - return this.survInfo.get(n - 1).getAtRiskCount(); + if (idx < 0) { + idx = ~idx; + if (idx == filled) { + --idx; + } } - return this.survInfo.get(idx).getAtRiskCount(); + return atRiskCounts[idx]; } /** @@ -355,14 +321,16 @@ public int getRiskSetCountAt(double time) { * @return Reversed estimator. */ public KaplanMeierEstimator reverse() { - KaplanMeierEstimator revKm = new KaplanMeierEstimator(); - for (int i = 0; i < this.survInfo.size(); i++) { - SurvInfo si = this.survInfo.get(i); - SurvInfo revSi = new SurvInfo(si.getTime(), si.getCensoredCount(), si.getEventsCount()); - revSi.setAtRiskCount(si.getAtRiskCount()); + KaplanMeierEstimator revKm = new KaplanMeierEstimator(); + + revKm.filled = filled; + revKm.probabilities = probabilities.clone(); + revKm.times = times.clone(); + revKm.atRiskCounts = atRiskCounts.clone(); + // switch places + revKm.censoredCounts = eventsCounts.clone(); + revKm.eventsCounts = censoredCounts.clone(); - revKm.survInfo.add(revSi); - } revKm.calculateProbability(); return revKm; } @@ -371,21 +339,15 @@ public KaplanMeierEstimator reverse() { * Fills the probabilities in K-M estimator. 
*/ protected void calculateProbability() { - - //Debug.Assert(new HashSet(this.survInfo.Select(s => s.Time)).Count == this.survInfo.Count); - - for (int i = 0; i < this.survInfo.size(); i++) { - SurvInfo si = this.survInfo.get(i); - - assert (int)si.getProbability() == SurvInfo.NotAssigned; - - si.setProbability((si.getAtRiskCount() - si.getEventsCount()) / ((double)si.getAtRiskCount())); + for (int i = 0; i < filled; i++) { + // assert probabilities[i] == SurvInfo.NotAssigned; + probabilities[i] = (atRiskCounts[i] - eventsCounts[i]) / (double)atRiskCounts[i]; if (i > 0) { - si.setProbability(si.getProbability() * this.survInfo.get(i - 1).getProbability()); + probabilities[i] *= probabilities[i-1]; } - assert si.getProbability() >= 0 && si.getProbability() <= 1; - } + // assert probabilities[i] >= 0 && probabilities[i] <= 1; + } } /** diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/RegressionRule.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/RegressionRule.java index c85afb09..ff7a705d 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/RegressionRule.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/RegressionRule.java @@ -138,6 +138,7 @@ public void covers(ExampleSet set, ContingencyTable ct, Set positives, //initially, everything as negatives List orderedNegatives = new ArrayList(set.size()); + for (int id = 0; id < set.size(); ++id) { Example ex = set.getExample(id); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/SortedExampleSetEx.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/SortedExampleSetEx.java index bc5cd4a8..0c24f86f 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/SortedExampleSetEx.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/logic/representation/SortedExampleSetEx.java @@ -16,6 +16,7 @@ public class SortedExampleSetEx extends SortedExampleSet { public double[] weights; public double[] labelsWeighted; public double[] totalWeightsBefore; + public double[] survivalTimes; public double meanLabel = 0; @@ -47,6 +48,11 @@ protected final void fillLabelsAndWeights() { weights = new double[this.size()]; totalWeightsBefore = new double[this.size() + 1]; + Attribute survTime = this.getAttributes().getSpecial(SurvivalRule.SURVIVAL_TIME_ROLE); + if (survTime != null) { + survivalTimes = new double[this.size()]; + } + boolean weighted = getAttributes().getWeight() != null; for (Attribute a: this.getAttributes()) { @@ -66,6 +72,10 @@ protected final void fillLabelsAndWeights() { totalWeightsBefore[i] = sumWeights; meanLabel += y; + if (survTime != null) { + survivalTimes[i] = e.getValue(survTime); + } + for (Attribute a: this.getAttributes()) { if (!Double.isNaN(e.getValue(a))) { nonMissingVals.get(a).add(i); diff --git a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/operator/RuleGenerator.java b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/operator/RuleGenerator.java index 32115f1c..75053549 100644 --- a/adaa.analytics.rules/src/main/java/adaa/analytics/rules/operator/RuleGenerator.java +++ b/adaa.analytics.rules/src/main/java/adaa/analytics/rules/operator/RuleGenerator.java @@ -176,6 +176,8 @@ protected enum MeasureDestination { public static final String PARAMETER_CONTROL_APRORI_PRECISION = "control_apriori_precision"; public static final String 
PARAMETER_APPROXIMATE_INDUCTION = "approximate_induction"; + + public static final String PARAMETER_APPROXIMATE_BINS_COUNT = "approximate_bins_count"; protected OperatorCommandProxy operatorCommandProxy; @@ -237,6 +239,7 @@ public Model learn(ExampleSet exampleSet) throws OperatorException { params.setMeanBasedRegression(getParameterAsBoolean(PARAMETER_MEAN_BASED_REGRESSION)); params.setControlAprioriPrecision(getParameterAsBoolean(PARAMETER_CONTROL_APRORI_PRECISION)); params.setApproximateInduction(getParameterAsBoolean(PARAMETER_APPROXIMATE_INDUCTION)); + params.setApproximateBinsCount(getParameterAsInt(PARAMETER_APPROXIMATE_BINS_COUNT)); String tmp = getParameterAsString(PARAMETER_MINCOV_ALL); if (tmp.length() > 0) { @@ -447,6 +450,9 @@ PARAMETER_PENALTY_SATURATION, getParameterDescription(PARAMETER_PENALTY_SATURATI types.add(new ParameterTypeBoolean(PARAMETER_APPROXIMATE_INDUCTION, getParameterDescription(PARAMETER_APPROXIMATE_INDUCTION), defaultParams.isApproximateInduction())); + types.add(new ParameterTypeInt(PARAMETER_APPROXIMATE_BINS_COUNT, getParameterDescription(PARAMETER_APPROXIMATE_BINS_COUNT), + 0, 1000, defaultParams.getApproximateBinsCount())); + return types; } diff --git a/adaa.analytics.rules/test/resources/config/RegressionExpertSnCTest.xml b/adaa.analytics.rules/test/resources/config/RegressionExpertSnCTest.xml index 190f2d2d..b5cb5f2b 100644 --- a/adaa.analytics.rules/test/resources/config/RegressionExpertSnCTest.xml +++ b/adaa.analytics.rules/test/resources/config/RegressionExpertSnCTest.xml @@ -4,6 +4,7 @@ 4 + false true false false @@ -19,6 +20,7 @@ 4 + false true false false @@ -33,6 +35,7 @@ 4 + false true false false @@ -47,6 +50,7 @@ 4 + false true false false diff --git a/adaa.analytics.rules/test/resources/config/RegressionSnCTest.xml b/adaa.analytics.rules/test/resources/config/RegressionSnCTest.xml index a0b32a98..e3c8372b 100644 --- a/adaa.analytics.rules/test/resources/config/RegressionSnCTest.xml +++ b/adaa.analytics.rules/test/resources/config/RegressionSnCTest.xml @@ -3,6 +3,7 @@ 4 + false diff --git a/adaa.analytics.rules/test/utils/testcases/InductionParametersFactory.java b/adaa.analytics.rules/test/utils/testcases/InductionParametersFactory.java index 08ae2f89..4212db8c 100644 --- a/adaa.analytics.rules/test/utils/testcases/InductionParametersFactory.java +++ b/adaa.analytics.rules/test/utils/testcases/InductionParametersFactory.java @@ -46,6 +46,9 @@ public static InductionParameters make(HashMap paramMap) { throw new RuntimeException("Unknown name of classification pruning measure: " + paramMap.get(key)); } break; + case "mean_based_regression": + parameters.setMeanBasedRegression(paramMap.get(key).equals("true")); + break; } } return parameters;
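The operator now exposes approximate_bins_count next to approximate_induction, and the test harness maps the mean_based_regression entry onto InductionParameters. A configuration sketch using only the setters that appear in this patch; the no-argument constructor and the chosen values are assumptions for illustration:

    import adaa.analytics.rules.logic.induction.InductionParameters;

    class InductionParameterSketch {
        InductionParameters configure() {
            InductionParameters params = new InductionParameters();   // assumed default constructor
            params.setApproximateInduction(true);   // switch on the approximate induction mode
            params.setApproximateBinsCount(100);    // new setting; the operator default comes from getApproximateBinsCount()
            params.setMeanBasedRegression(false);   // the survival finders force this to false in their constructors
            return params;
        }
    }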