Permutect dataset engine outputs contig and read group indices, not names (#8860)
davidbenjamin authored Jun 4, 2024
1 parent 2878ce5 commit 2a420e4
Showing 6 changed files with 71 additions and 17 deletions.
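The substance of the change: instead of writing contig and read group names into the Permutect training dataset, the engine now writes integer indices and emits two sidecar tables (contigs.table and read-groups.table) that map names to indices. A minimal sketch of the two mappings, using htsjdk's SAMFileHeader and SAMSequenceDictionary (an illustration of the idea, not the commit's code verbatim):

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

final class IndexMappingSketch {
    // read groups are numbered by their order in the BAM header
    static Map<String, Integer> readGroupIndices(final SAMFileHeader header) {
        final Map<String, Integer> indices = new HashMap<>();
        final List<SAMReadGroupRecord> readGroups = header.getReadGroups();
        for (int n = 0; n < readGroups.size(); n++) {
            indices.put(readGroups.get(n).getReadGroupId(), n);
        }
        return indices;
    }

    // contigs are numbered by their index in the sequence dictionary
    static int contigIndex(final SAMSequenceDictionary dict, final String contig) {
        return dict.getSequenceIndex(contig);
    }
}

Downstream consumers of dataset.txt can translate the indices back to names via the sidecar tables.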
8 changes: 7 additions & 1 deletion scripts/mutect2_wdl/mutect2.wdl
@@ -314,6 +314,8 @@ workflow Mutect2 {
File? maf_segments = CalculateContamination.maf_segments
File? read_orientation_model_params = LearnReadOrientationModel.artifact_prior_table
File? m3_dataset = Concatenate.concatenated
File permutect_contigs_table = select_first(M2.permutect_contigs_table)
File permutect_read_groups_table = select_first(M2.permutect_read_groups_table)
}
}

@@ -442,6 +444,8 @@ task M2 {
touch bamout.bam
touch f1r2.tar.gz
touch dataset.txt
touch contigs.table
touch read-groups.table

if [[ ! -z "~{normal_reads}" ]]; then
gatk --java-options "-Xmx~{command_mem}m" GetSampleName -R ~{ref_fasta} -I ~{normal_reads} -O normal_names.txt -encode \
@@ -476,7 +480,7 @@

# If the variants for contamination and the intervals for this scatter don't intersect, GetPileupSummaries
# throws an error. However, there is nothing wrong with an empty intersection for our purposes; it simply doesn't
# contribute to the merged pileup summaries that we create downstream. We implement this by with array outputs.
# contribute to the merged pileup summaries that we create downstream. We implement this via array outputs.
# If the tool errors, no table is created and the glob yields an empty array.
set +e

@@ -516,6 +520,8 @@ task M2 {
Array[File] tumor_pileups = glob("*tumor-pileups.table")
Array[File] normal_pileups = glob("*normal-pileups.table")
File m3_dataset = "dataset.txt"
File permutect_contigs_table = "contigs.table"
File permutect_read_groups_table = "read-groups.table"
}
}

src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/FeaturizedReadSets.java
@@ -50,8 +50,9 @@ public static List<List<List<Integer>>> getReadVectors(final VariantContext vc,
final AlleleLikelihoods<Fragment, Haplotype> haplotypeLikelihoods,
final int refDownsample,
final int altDownsample,
final M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode) {
return getReadVectors(vc, samples, likelihoods, haplotypeLikelihoods, refDownsample, altDownsample, Collections.emptyMap(), mutect3DatasetMode);
final M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode,
final Map<String, Integer> readGroupIndices) {
return getReadVectors(vc, samples, likelihoods, haplotypeLikelihoods, refDownsample, altDownsample, Collections.emptyMap(), mutect3DatasetMode, readGroupIndices);
}

// returns Lists (in allele order) of lists of read vectors supporting each allele
@@ -62,7 +63,8 @@ public static List<List<List<Integer>>> getReadVectors(final VariantContext vc,
final int refDownsample,
final int altDownsample,
final Map<Allele, Integer> altDownsampleMap,
final M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode) {
final M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode,
final Map<String, Integer> readGroupIndices) {
final Map<Allele, List<GATKRead>> readsByAllele = likelihoods.alleles().stream()
.collect(Collectors.toMap(a -> a, a -> new ArrayList<>()));

@@ -85,15 +87,17 @@ public static List<List<List<Integer>>> getReadVectors(final VariantContext vc,
.forEach(ba -> ba.evidence.getReads().forEach(read -> bestHaplotypes.put(read, ba.allele)));

return vc.getAlleles().stream()
.map(allele -> readsByAllele.get(allele).stream().map(read -> featurize(read, vc, bestHaplotypes, mutect3DatasetMode)).collect(Collectors.toList()))
.map(allele -> readsByAllele.get(allele).stream().map(read -> featurize(read, vc, bestHaplotypes, mutect3DatasetMode, readGroupIndices)).collect(Collectors.toList()))
.collect(Collectors.toList());
}


private static List<Integer> featurize(final GATKRead read, final VariantContext vc,
final Map<GATKRead, Haplotype> bestHaplotypes,
final M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode) {
final M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode,
final Map<String, Integer> readGroupIndices) {
final List<Integer> result = new ArrayList<>();
result.add(readGroupIndices.get(read.getReadGroup())); // this is read group metadata rather than part of the tensor
result.add(read.getMappingQuality());
result.add(BaseQuality.getBaseQuality(read, vc).orElse(DEFAULT_BASE_QUALITY));
result.add(read.isFirstOfPair() ? 1 : 0);
@@ -190,7 +194,8 @@ private static List<Integer> featurize(final GATKRead read, final VariantContext
}
}
}
Utils.validate(result.size() == mutect3DatasetMode.getNumReadFeatures(), "Wrong number of features");
// the +1 is for the read group index that comes before the features
Utils.validate(result.size() == mutect3DatasetMode.getNumReadFeatures() + 1, "Wrong number of features");

return result;
}
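The featurize() change above prepends the read group index to each per-read vector as metadata, so every vector's length becomes getNumReadFeatures() + 1. A hedged sketch of the resulting layout (NUM_FEATURES and the features array are illustrative placeholders, not GATK API):

import java.util.ArrayList;
import java.util.List;

final class ReadVectorLayoutSketch {
    static final int NUM_FEATURES = 11; // illustrative; the real count comes from Mutect3DatasetMode

    // position 0 holds the read group index (metadata, not a model feature);
    // positions 1..NUM_FEATURES hold the per-read features
    static List<Integer> featurize(final int readGroupIndex, final int[] features) {
        final List<Integer> result = new ArrayList<>();
        result.add(readGroupIndex);
        for (final int feature : features) {
            result.add(feature);
        }
        // mirrors the commit's validation: size == getNumReadFeatures() + 1
        if (result.size() != NUM_FEATURES + 1) {
            throw new IllegalStateException("Wrong number of features");
        }
        return result;
    }
}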
src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2.java
@@ -1,5 +1,6 @@
package org.broadinstitute.hellbender.tools.walkers.mutect;

import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.variant.variantcontext.writer.VariantContextWriter;
import org.broadinstitute.barclay.argparser.Argument;
import org.broadinstitute.barclay.argparser.ArgumentCollection;
@@ -262,7 +263,8 @@ public boolean shouldTrackPileupsForAssemblyRegions() {
@Override
public void onTraversalStart() {
VariantAnnotatorEngine annotatorEngine = new VariantAnnotatorEngine(makeVariantAnnotations(), null, Collections.emptyList(), false, false);
m2Engine = new Mutect2Engine(MTAC, assemblyRegionArgs, createOutputBamIndex, createOutputBamMD5, getHeaderForReads(), referenceArguments.getReferenceSpecifier(), annotatorEngine);
m2Engine = new Mutect2Engine(MTAC, assemblyRegionArgs, createOutputBamIndex, createOutputBamMD5, getHeaderForReads(),
getBestAvailableSequenceDictionary(), referenceArguments.getReferenceSpecifier(), annotatorEngine);
vcfWriter = createVCFWriter(outputVCF);
if (m2Engine.emitReferenceConfidence()) {
logger.warn("Note that the Mutect2 reference confidence mode is in BETA -- the likelihoods model and output format are subject to change in subsequent versions.");
src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect2Engine.java
@@ -1,6 +1,7 @@
package org.broadinstitute.hellbender.tools.walkers.mutect;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
@@ -94,6 +95,7 @@ public final class Mutect2Engine implements AssemblyRegionEvaluator, AutoCloseab

final private M2ArgumentCollection MTAC;
private SAMFileHeader header;
private SAMSequenceDictionary sequenceDictionary;
private final int minCallableDepth;
public static final String CALLABLE_SITES_NAME = "callable";

@@ -136,9 +138,12 @@ public final class Mutect2Engine implements AssemblyRegionEvaluator, AutoCloseab
* @param referenceSpec reference specifier for the reference
* @param annotatorEngine annotator engine built with desired annotations
*/
public Mutect2Engine(final M2ArgumentCollection MTAC, AssemblyRegionArgumentCollection assemblyRegionArgs, final boolean createBamOutIndex, final boolean createBamOutMD5, final SAMFileHeader header, final GATKPath referenceSpec, final VariantAnnotatorEngine annotatorEngine) {
public Mutect2Engine(final M2ArgumentCollection MTAC, AssemblyRegionArgumentCollection assemblyRegionArgs,
final boolean createBamOutIndex, final boolean createBamOutMD5, final SAMFileHeader header,
final SAMSequenceDictionary sequenceDictionary, final GATKPath referenceSpec, final VariantAnnotatorEngine annotatorEngine) {
this.MTAC = Utils.nonNull(MTAC);
this.header = Utils.nonNull(header);
this.sequenceDictionary = sequenceDictionary;
minCallableDepth = MTAC.callableDepth;
referenceReader = ReferenceUtils.createReferenceReader(Utils.nonNull(referenceSpec));
aligner = SmithWatermanAligner.getAligner(MTAC.smithWatermanImplementation);
@@ -162,7 +167,7 @@ public Mutect2Engine(final M2ArgumentCollection MTAC, AssemblyRegionArgumentColl
annotationEngine = Utils.nonNull(annotatorEngine);
assemblyEngine = MTAC.createReadThreadingAssembler();
likelihoodCalculationEngine = AssemblyBasedCallerUtils.createLikelihoodCalculationEngine(MTAC.likelihoodArgs, MTAC.fbargs, true, MTAC.likelihoodArgs.likelihoodEngineImplementation);
genotypingEngine = new SomaticGenotypingEngine(MTAC, normalSamples, annotationEngine);
genotypingEngine = new SomaticGenotypingEngine(MTAC, normalSamples, annotationEngine, header, sequenceDictionary);
haplotypeBAMWriter = AssemblyBasedCallerUtils.createBamWriter(MTAC, createBamOutIndex, createBamOutMD5, header);
trimmer = new AssemblyRegionTrimmer(assemblyRegionArgs, header.getSequenceDictionary());

src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/Mutect3DatasetEngine.java
@@ -1,5 +1,9 @@
package org.broadinstitute.hellbender.tools.walkers.mutect;

import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMReadGroupRecord;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.SAMSequenceRecord;
import htsjdk.variant.variantcontext.Allele;
import htsjdk.variant.variantcontext.Genotype;
import htsjdk.variant.variantcontext.VariantContext;
@@ -43,6 +47,10 @@ private enum Label {
ARTIFACT, VARIANT, UNLABELED, IGNORE
}

private final SAMSequenceDictionary sequenceDictionary;

private final Map<String, Integer> readGroupIndices = new HashMap<>();

// number of additional variant features for assembly complexity, reference context
private static final int NUM_EXTRA_FEATURES = 9;

@@ -65,6 +73,8 @@ private enum Label {
private static final int MIN_REF = 5;

private final PrintWriter printWriter;
private final PrintWriter contigPrintWriter;
private final PrintWriter readGroupPrintWriter;

// number of nonartifact data to keep for each artifact datum
private final int nonArtifactPerArtifact;
@@ -79,9 +89,15 @@ private enum Label {
private final EnumMap<VariantType, ArrayBlockingQueue<Integer>> unmatchedArtifactAltCounts;


public Mutect3DatasetEngine(final File datasetFile, final boolean trainingMode, final int maxRefCount, final int maxAltCount, final int nonArtifactPerArtifact, final Set<String> normalSamples) {
public Mutect3DatasetEngine(final File datasetFile, final boolean trainingMode, final int maxRefCount,
final int maxAltCount, final int nonArtifactPerArtifact, final Set<String> normalSamples,
final SAMFileHeader header, final SAMSequenceDictionary sequenceDictionary) {
try {
printWriter = new PrintWriter(new FileWriter(Utils.nonNull(datasetFile)));
final File contigTableFile = datasetFile.toPath().resolveSibling("contigs.table").toFile();
final File readGroupTableFile = datasetFile.toPath().resolveSibling("read-groups.table").toFile();
contigPrintWriter = new PrintWriter(new FileWriter(contigTableFile));
readGroupPrintWriter = new PrintWriter(new FileWriter(readGroupTableFile));
} catch (IOException ex) {
throw new UserException.BadInput("Could not create dataset file writer");
}
@@ -92,6 +108,12 @@ public Mutect3DatasetEngine(final File datasetFile, final boolean trainingMode,
this.maxRefCount = maxRefCount;
this.maxAltCount = maxAltCount;

this.sequenceDictionary = sequenceDictionary;
final List<SAMReadGroupRecord> readGroups = header.getReadGroups();
for (int n = 0; n < readGroups.size(); n++) {
readGroupIndices.put(readGroups.get(n).getReadGroupId(), n);
}

unmatchedArtifactAltCounts = new EnumMap<>(VariantType.class);
for (final VariantType type : VariantType.values()) {
unmatchedArtifactAltCounts.put(type, new ArrayBlockingQueue<>(CAPACITY));
@@ -106,7 +128,7 @@ public void addData(final ReferenceContext ref, final VariantContext vc, Optiona
final M2ArgumentCollection.Mutect3DatasetMode mutect3DatasetMode) {
final String refBases = ReferenceBases.annotate(ref, vc);
final String refAllele = vc.getReference().getBaseString();
final String contig = vc.getContig();
final int contigIndex = sequenceDictionary.getSequenceIndex(vc.getContig());
final int position = vc.getStart();
final Set<String> tumorSamples = likelihoods.samples().stream().filter(sample -> !normalSamples.contains(sample)).collect(Collectors.toSet());
final int numAlt = vc.getNAlleles() - 1;
@@ -204,9 +226,9 @@ public void addData(final ReferenceContext ref, final VariantContext vc, Optiona
// TODO: for now we don't really need normal reads
// note that the following use the VC's allele order, not necessarily the likelihoods' allele order
final List<List<List<Integer>>> normalReadVectorsByAllele = FeaturizedReadSets.getReadVectors(vc, normalSamples,
likelihoods, logFragmentLikelihoods, maxRefCount, maxAltCount, mutect3DatasetMode);
likelihoods, logFragmentLikelihoods, maxRefCount, maxAltCount, mutect3DatasetMode, readGroupIndices);
final List<List<List<Integer>>> tumorReadVectorsByAllele = FeaturizedReadSets.getReadVectors(vc, tumorSamples,
likelihoods, logFragmentLikelihoods, maxRefCount, maxAltCount, altDownsampleMap, mutect3DatasetMode);
likelihoods, logFragmentLikelihoods, maxRefCount, maxAltCount, altDownsampleMap, mutect3DatasetMode, readGroupIndices);

// ref and alt reads have already been downsampled by the read featurizer
final List<List<Integer>> tumorRefReads = tumorReadVectorsByAllele.get(0);
@@ -227,7 +249,7 @@ public void addData(final ReferenceContext ref, final VariantContext vc, Optiona
final List<List<Integer>> normalAltReads = normalReadVectorsByAllele.get(n+1);

printWriter.println(labels.get(n).toString());
printWriter.printf("%s:%d,%s->%s%n", contig, position, refAllele, altAllele);
printWriter.printf("%d:%d,%s->%s%n", contigIndex, position, refAllele, altAllele);
printWriter.println(refBases);
printWriter.println(numberString(variantFeatureVector, "%.2f", " "));
//printWriter.printf("%d %d %d %d%n", tumorRefReads.size(), tumorAltReads.size(), normalRefReads.size(), normalAltReads.size());
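With this change the locus line of each dataset record is keyed by sequence-dictionary index rather than contig name. For example (illustrative values, assuming chr1 has sequence index 0):

before: chr1:12345,A->G
after:  0:12345,A->G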
@@ -327,5 +349,16 @@ private int[] sumADsOverSamples(final VariantContext vc, final Set<String> sampl
@Override
public void close() {
printWriter.close();

for (final SAMSequenceRecord contigRecord : sequenceDictionary.getSequences()) {
contigPrintWriter.println(String.format("%s\t%d", contigRecord.getContig(), contigRecord.getSequenceIndex()));
}

for (final Map.Entry<String, Integer> entry : readGroupIndices.entrySet()) {
readGroupPrintWriter.println(String.format("%s\t%d", entry.getKey(), entry.getValue()));
}

contigPrintWriter.close();
readGroupPrintWriter.close();
}
}
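A hypothetical downstream reader (not part of this commit) can invert these tables: close() writes one name<TAB>index line per contig and per read group, so parsing them back into index-to-name maps takes a few lines:

import java.io.IOException;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.HashMap;
import java.util.Map;

final class SidecarTableReader {
    // parses the name<TAB>index lines written by Mutect3DatasetEngine.close()
    static Map<Integer, String> readIndexToName(final Path table) throws IOException {
        final Map<Integer, String> indexToName = new HashMap<>();
        for (final String line : Files.readAllLines(table)) {
            final String[] fields = line.split("\t");
            indexToName.put(Integer.parseInt(fields[1]), fields[0]);
        }
        return indexToName;
    }
}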
src/main/java/org/broadinstitute/hellbender/tools/walkers/mutect/SomaticGenotypingEngine.java
@@ -4,6 +4,7 @@
import com.google.common.collect.ImmutableMap;
import com.google.common.primitives.Doubles;
import htsjdk.samtools.SAMFileHeader;
import htsjdk.samtools.SAMSequenceDictionary;
import htsjdk.samtools.util.Locatable;
import htsjdk.variant.variantcontext.*;
import htsjdk.variant.vcf.VCFConstants;
@@ -52,7 +53,9 @@ public class SomaticGenotypingEngine implements AutoCloseable {
private final double refPseudocount = 1;
private final double altPseudocount;

public SomaticGenotypingEngine(final M2ArgumentCollection MTAC, final Set<String> normalSamples, final VariantAnnotatorEngine annotationEngine) {
public SomaticGenotypingEngine(final M2ArgumentCollection MTAC, final Set<String> normalSamples,
final VariantAnnotatorEngine annotationEngine,
final SAMFileHeader header, final SAMSequenceDictionary sequenceDictionary) {
this.MTAC = MTAC;
altPseudocount = MTAC.minAF == 0.0 ? 1 : 1 - Math.log(2)/Math.log(MTAC.minAF);

@@ -62,7 +65,7 @@ public SomaticGenotypingEngine(final M2ArgumentCollection MTAC, final Set<String

mutect3DatasetEngine = MTAC.mutect3Dataset == null ? Optional.empty() :
Optional.of(new Mutect3DatasetEngine(MTAC.mutect3Dataset, MTAC.mutect3TrainingDataMode, MTAC.maxRefCountForMutect3,
MTAC.maxAltCountForMutect3, MTAC.mutect3NonArtifactRatio, normalSamples));
MTAC.maxAltCountForMutect3, MTAC.mutect3NonArtifactRatio, normalSamples, header, sequenceDictionary));
Utils.validateArg(!(MTAC.mutect3Dataset == null && MTAC.mutect3TrainingDataMode), "No dataset file specified for Mutect3 training data mode.");
}

