/*
 * Decompiled with CFR 0.152.
 */
package imputationtool.postprocessing;

import JSci.maths.ArrayMath;
import imputationtool.postprocessing.ScatterPlot;
import java.io.IOException;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.regex.Pattern;
import umcg.genetica.io.Gpio;
import umcg.genetica.io.text.TextFile;
import umcg.genetica.io.trityper.SNP;
import umcg.genetica.io.trityper.SNPLoader;
import umcg.genetica.io.trityper.TriTyperGenotypeData;

/**
 * Compares the genotypes of SNPs that are present in two TriTyper genotype datasets.
 *
 * <p>For every SNP of dataset 1 that also exists in dataset 2, {@link #run(String)} extracts the
 * genotypes (or dosages, when the SNP provides them) of the samples that are included in both
 * datasets and computes their correlation with JSci's {@code ArrayMath.correlation}. SNPs that
 * fail the quality checks (call rate, MAF, chromosome position, alleles) are skipped. Progress
 * and summary statistics are printed to standard output, and histograms of R2 are written to
 * {@code correlationOutput.txt} in the output directory. When Beagle R2 files were loaded, the
 * observed R2 of each SNP is also plotted against the mean Beagle R2 of that SNP.
 */
public class TriTyperDatasetCorrelator {
    /** Number of R2 histogram bins: a value v goes to bin (int) (v * 10), clamped to [0, 10]. */
    private static final int NUM_BINS = 11;
    /** Value that marks a missing genotype/dosage in the arrays produced by {@link #getGenotypes}. */
    private static final double MISSING = -1.0;
    /** Column separator of the Beagle .r2 files. */
    private static final Pattern TAB = Pattern.compile("\t");
    private static final String[] ALPHABET = new String[]{"a", "b", "c", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o", "p", "q", "r", "s", "t", "u", "v", "w", "x", "y", "z"};
    /** Maps an ASCII allele code to 0=A, 1=C, 2=G, 3=T; every other code maps to 4 (not counted). */
    private static final int[] ALLELE_INDEX = buildAlleleIndex();

    /** Only reported in the start-up message; this class does not start any threads itself. */
    int nrThreads = 2;
    private TriTyperGenotypeData ggDataset1;
    private TriTyperGenotypeData ggDataset2;
    /** Output log; opened and closed by {@link #run(String)}. */
    private TextFile log;
    /** SNP list set by {@link #confineToSNPs(String)}; null when no list was given. */
    private HashSet<String> confineSNPList;
    /** Beagle R2 values per SNP name; null unless every expected Beagle file was readable. */
    private HashMap<String, ArrayList<Double>> beagleR2;
    /** Histogram of all loaded Beagle R2 values; null unless Beagle files were loaded. */
    private int[] beaglecorrelationfreqdistribution;

    /**
     * Loads both datasets.
     *
     * @param dataset1 location passed to {@code TriTyperGenotypeData.load} for dataset 1
     * @param dataset1Name currently unused
     * @param dataset2 location passed to {@code TriTyperGenotypeData.load} for dataset 2
     * @param dataset2Name currently unused
     * @throws IOException if loading a dataset fails
     */
    public TriTyperDatasetCorrelator(String dataset1, String dataset1Name, String dataset2, String dataset2Name) throws IOException {
        System.out.println("Running with " + this.nrThreads + " threads.");
        this.loadDatasets(dataset1, dataset2);
    }

    /**
     * Loads Beagle R2 values and then both datasets.
     *
     * @param dataset1 location of dataset 1
     * @param dataset1Name currently unused
     * @param dataset2 location of dataset 2
     * @param dataset2Name currently unused
     * @param beagleInput directory containing the Beagle {@code .r2} files
     * @param template file name template; {@code BATCH} and {@code CHROMOSOME} are substituted
     * @param numBatches number of batches per chromosome
     * @throws IOException if reading the Beagle files or loading a dataset fails
     */
    public TriTyperDatasetCorrelator(String dataset1, String dataset1Name, String dataset2, String dataset2Name, String beagleInput, String template, Integer numBatches) throws IOException {
        System.out.println("Running with " + this.nrThreads + " threads.");
        this.loadBeagleR2(beagleInput, template, numBatches);
        this.loadDatasets(dataset1, dataset2);
    }

    /** Loads the two TriTyper datasets into {@link #ggDataset1} and {@link #ggDataset2}. */
    private void loadDatasets(String dataset1, String dataset2) throws IOException {
        this.ggDataset1 = new TriTyperGenotypeData();
        this.ggDataset1.load(dataset1);
        this.ggDataset2 = new TriTyperGenotypeData();
        this.ggDataset2.load(dataset2);
    }

    /**
     * Extracts the values of the given samples for one SNP.
     *
     * <p>Dosage values are used when the SNP provides them; otherwise the genotype codes are
     * returned as doubles. The SNP's genotypes must already be loaded.
     *
     * @param snp SNP whose values are read
     * @param gg currently unused
     * @param inds sample indices to extract, in output order
     * @return one value per entry of {@code inds}
     */
    public double[] getGenotypes(SNP snp, TriTyperGenotypeData gg, int[] inds) {
        double[] tmpGenotypes = new double[inds.length];
        double[] dosage = snp.getDosageValues();
        byte[] genotypes = snp.getGenotypes();
        for (int i = 0; i < inds.length; ++i) {
            tmpGenotypes[i] = dosage == null ? genotypes[inds[i]] : dosage[inds[i]];
        }
        return tmpGenotypes;
    }

    /**
     * Correlates all SNPs shared by both datasets and writes the R2 distributions.
     *
     * <p>Writes {@code correlationOutput.txt} (and, when Beagle data is loaded,
     * {@code CorrelationVsBeagleR2.png}) into {@code outputLocation}. Terminates the JVM with
     * status -1 when the datasets share no included samples.
     *
     * @param outputLocation output directory
     * @throws IOException if reading genotypes or writing output fails
     */
    public void run(String outputLocation) throws IOException {
        this.log = new TextFile(outputLocation + "/correlationOutput.txt", true);
        boolean hasSharedSamples;
        try {
            hasSharedSamples = this.correlateDatasets(outputLocation);
        } finally {
            this.log.close();
        }
        if (!hasSharedSamples) {
            // Without shared samples nothing can be compared; terminate with a non-zero status.
            System.exit(-1);
        }
    }

    /**
     * Body of {@link #run(String)}; the caller owns opening and closing {@link #log}.
     *
     * @return false when the datasets share no included samples (nothing was correlated)
     */
    private boolean correlateDatasets(String outputLocation) throws IOException {
        this.determineUniqueSNPS();
        int[][] shared = this.mapSharedIndividuals();
        int[] inds1Final = shared[0];
        int[] inds2Final = shared[1];
        System.out.println("Shared samples: " + inds1Final.length);
        if (inds1Final.length == 0) {
            return false;
        }
        for (int c = 0; c < inds1Final.length; ++c) {
            System.out.println(c + "\t" + inds1Final[c] + "\t" + inds2Final[c]);
        }
        int counter = 0;
        int below80 = 0;
        int belowSquared80 = 0;
        int flipped = 0;
        String[] snps = this.ggDataset1.getSNPs();
        int[] correlationfreqdistribution = new int[NUM_BINS];
        ScatterPlot s = null;
        if (this.beagleR2 != null) {
            s = new ScatterPlot(1000);
        }
        System.out.println("Mapping SNPs");
        SNPLoader loader1 = this.ggDataset1.createSNPLoader();
        SNPLoader loader2 = this.ggDataset2.createSNPLoader();
        for (int snp1 = 0; snp1 < snps.length; ++snp1) {
            Integer snp2 = this.ggDataset2.getSnpToSNPId().get(snps[snp1]);
            if (snp2 == null || snp2.intValue() < 0) {
                // Absent from dataset 2; the lookup may signal absence with null or a negative id.
                continue;
            }
            SNP snp1Object = this.ggDataset1.getSNPObject(snp1);
            SNP snp2Object = this.ggDataset2.getSNPObject(snp2.intValue());
            loader1.loadGenotypes(snp1Object);
            loader2.loadGenotypes(snp2Object);

            double[] genotypes1 = this.getGenotypes(snp1Object, this.ggDataset1, inds1Final);
            double[] genotypes2 = this.getGenotypes(snp2Object, this.ggDataset2, inds2Final);
            // A sample missing in either dataset is treated as missing in both.
            int missingGenotypes = 0;
            for (int g = 0; g < genotypes1.length; ++g) {
                if (genotypes1[g] == MISSING || genotypes2[g] == MISSING) {
                    genotypes1[g] = MISSING;
                    genotypes2[g] = MISSING;
                    ++missingGenotypes;
                }
            }

            boolean exclude = false;
            String excludeReason = "";
            if ((double) missingGenotypes / (double) genotypes1.length > 0.1) {
                exclude = true;
                excludeReason = "SNP has low callrate (> 10%): Missing: " + missingGenotypes + " / " + genotypes1.length;
            }
            double maf1 = snp1Object.getMAF();
            double maf2 = snp2Object.getMAF();
            if (maf1 < 0.05 || maf2 < 0.05) {
                exclude = true;
                excludeReason = excludeReason + "\tMAF < 0.05: " + maf1 + "\t" + maf2;
            }
            if (this.ggDataset1.getChr(snp1) != null) {
                byte chr1 = this.ggDataset1.getChr(snp1);
                int pos1 = this.ggDataset1.getChrPos(snp1);
                byte chr2 = this.ggDataset2.getChr(snp2.intValue());
                int pos2 = this.ggDataset2.getChrPos(snp2.intValue());
                if (chr1 != chr2 || pos1 != pos2) {
                    exclude = true;
                    excludeReason = excludeReason + "\tSNPs map to different positions";
                }
            }

            byte[] alleles1Bytes = snp1Object.getAlleles();
            String alleles1 = new String(alleles1Bytes, "UTF-8");
            String alleles2 = new String(snp2Object.getAlleles(), "UTF-8");
            if (alleles1Bytes[0] == 0 || alleles1Bytes[1] == 0) {
                exclude = true;
                excludeReason = excludeReason + "\tSNP has allele code 0";
            }

            // Strand resolution: if the two datasets together show more than two distinct
            // alleles, retry with dataset 2's allele counts complemented (A<->T, C<->G).
            int[] counts2 = countAlleles(snp2Object, this.ggDataset2);
            int[] counts1 = countAlleles(snp1Object, this.ggDataset1);
            double[] freq1 = new double[4];
            double[] freq2 = new double[4];
            boolean takeComplement = false;
            if (fillFrequencies(counts1, counts2, freq1, freq2) > 2) {
                complementCounts(counts2);
                takeComplement = true;
                if (fillFrequencies(counts1, counts2, freq1, freq2) > 2) {
                    takeComplement = false;
                    exclude = true;
                    excludeReason = excludeReason + "\tIncompatibleAlleles:Dataset=" + alleles2 + ",HapMap=" + alleles1;
                }
            }
            // A/T (65+84) and C/G (67+71) SNPs cannot be strand-resolved from the allele letters
            // alone, so discordant allele frequencies are taken as a sign of a strand flip.
            byte[] snpAlleles = snp2Object.getAlleles();
            int alleleSum = snpAlleles[0] + snpAlleles[1];
            boolean ambiguousStrand = alleleSum == 'A' + 'T' || alleleSum == 'C' + 'G';
            if (ambiguousStrand && hasDiscordantFrequencies(freq1, freq2)) {
                takeComplement = !takeComplement;
            }

            if (exclude) {
                // NOTE(review): only excluded SNPs that also needed a complement are reported;
                // confirm this is intended rather than reporting every exclusion.
                if (takeComplement) {
                    System.out.println(snp1Object.getName() + "\t" + excludeReason);
                }
            } else {
                double[] finalGenotypes1 = genotypes1;
                double[] finalGenotypes2 = genotypes2;
                if (missingGenotypes > 0) {
                    finalGenotypes1 = dropMissing(genotypes1, genotypes1, missingGenotypes);
                    finalGenotypes2 = dropMissing(genotypes2, genotypes1, missingGenotypes);
                }
                double correlation = ArrayMath.correlation(finalGenotypes1, finalGenotypes2);
                double absCor = Math.abs(correlation);
                double absCorSquared = absCor * absCor;
                if (this.beagleR2 != null) {
                    ArrayList<Double> r2s = this.beagleR2.get(snp1Object.getName());
                    // SNPs without Beagle values are simply not plotted.
                    if (r2s != null && !r2s.isEmpty()) {
                        s.plot(absCorSquared, mean(r2s));
                    }
                }
                ++correlationfreqdistribution[binFor(absCorSquared)];
                if (absCor <= 0.8) {
                    ++below80;
                }
                if (absCorSquared <= 0.8) {
                    ++belowSquared80;
                }
                if (correlation < 0.0) {
                    ++flipped;
                }
                ++counter;
                if (counter % 10000 == 0) {
                    System.out.println(formatProgress(counter, below80, belowSquared80, flipped));
                }
            }
            snp1Object.clearGenotypes();
            snp2Object.clearGenotypes();
        }
        if (this.beagleR2 != null) {
            s.draw(outputLocation + "/CorrelationVsBeagleR2.png");
        }
        System.out.println(formatProgress(counter, below80, belowSquared80, flipped));
        this.writeDistributions(correlationfreqdistribution);
        return true;
    }

    /**
     * Finds the samples that are included in both datasets.
     *
     * @return two parallel arrays of equal length: [0] holds dataset-1 sample indices and [1]
     *     the matching dataset-2 indices, in dataset-1 order
     */
    private int[][] mapSharedIndividuals() {
        String[] individuals1 = this.ggDataset1.getIndividuals();
        int[] shared1 = new int[individuals1.length];
        int[] shared2 = new int[individuals1.length];
        int numShared = 0;
        for (int i = 0; i < individuals1.length; ++i) {
            if (!this.ggDataset1.getIsIncluded()[i].booleanValue()) {
                continue;
            }
            Integer id2 = this.ggDataset2.getIndividualId(individuals1[i]);
            if (id2 == null || id2.intValue() < 0 || !this.ggDataset2.getIsIncluded()[id2.intValue()].booleanValue()) {
                continue;
            }
            shared1[numShared] = i;
            shared2[numShared] = id2.intValue();
            ++numShared;
        }
        int[] inds1Final = new int[numShared];
        int[] inds2Final = new int[numShared];
        System.arraycopy(shared1, 0, inds1Final, 0, numShared);
        System.arraycopy(shared2, 0, inds2Final, 0, numShared);
        return new int[][]{inds1Final, inds2Final};
    }

    /** Builds the lookup table behind {@link #ALLELE_INDEX}. */
    private static int[] buildAlleleIndex() {
        int[] index = new int[256];
        for (int a = 0; a < index.length; ++a) {
            index[a] = 4;
        }
        index['A'] = 0;
        index['C'] = 1;
        index['G'] = 2;
        index['T'] = 3;
        return index;
    }

    /**
     * Counts both alleles of every included sample of {@code data} for one SNP.
     *
     * @return counts indexed as in {@link #ALLELE_INDEX} (index 4 = any other allele code)
     */
    private static int[] countAlleles(SNP snp, TriTyperGenotypeData data) {
        int[] counts = new int[5];
        byte[] allele1 = snp.getAllele1();
        byte[] allele2 = snp.getAllele2();
        int numInd = data.getIndividuals().length;
        for (int ind = 0; ind < numInd; ++ind) {
            if (!data.getIsIncluded()[ind].booleanValue()) {
                continue;
            }
            // Mask to 0..255: a plain byte index would be negative for codes >= 128.
            ++counts[ALLELE_INDEX[allele1[ind] & 0xFF]];
            ++counts[ALLELE_INDEX[allele2[ind] & 0xFF]];
        }
        return counts;
    }

    /**
     * Converts A/C/G/T counts to frequencies (index 4 is ignored).
     *
     * @return the number of alleles with a non-zero frequency in at least one of the two arrays
     */
    private static int fillFrequencies(int[] countsA, int[] countsB, double[] freqA, double[] freqB) {
        int totalA = 0;
        int totalB = 0;
        for (int a = 0; a < 4; ++a) {
            totalA += countsA[a];
            totalB += countsB[a];
        }
        int present = 0;
        for (int a = 0; a < 4; ++a) {
            freqA[a] = (double) countsA[a] / (double) totalA;
            freqB[a] = (double) countsB[a] / (double) totalB;
            if (freqA[a] > 0.0 || freqB[a] > 0.0) {
                ++present;
            }
        }
        return present;
    }

    /** Swaps A/T and C/G counts in place, i.e. moves the counts to the opposite strand. */
    private static void complementCounts(int[] counts) {
        int a = counts[0];
        counts[0] = counts[3];
        counts[3] = a;
        int c = counts[1];
        counts[1] = counts[2];
        counts[2] = c;
    }

    /** True when some allele present in both datasets is above 0.5 in one and below 0.5 in the other. */
    private static boolean hasDiscordantFrequencies(double[] freqA, double[] freqB) {
        for (int a = 0; a < 4; ++a) {
            if (!(freqA[a] > 0.0) || !(freqB[a] > 0.0)) {
                continue;
            }
            if ((freqA[a] > 0.5 && freqB[a] < 0.5) || (freqA[a] < 0.5 && freqB[a] > 0.5)) {
                return true;
            }
        }
        return false;
    }

    /** Returns the entries of {@code values} whose position in {@code reference} is not missing. */
    private static double[] dropMissing(double[] values, double[] reference, int missing) {
        double[] kept = new double[values.length - missing];
        int k = 0;
        for (int i = 0; i < values.length; ++i) {
            if (reference[i] == MISSING) {
                continue;
            }
            kept[k++] = values[i];
        }
        return kept;
    }

    /** Arithmetic mean of a non-empty list. */
    private static double mean(ArrayList<Double> values) {
        double sum = 0.0;
        for (int i = 0; i < values.size(); ++i) {
            sum += values.get(i).doubleValue();
        }
        return sum / (double) values.size();
    }

    /**
     * Histogram bin for an R2 value: (int) (value * 10), clamped to [0, NUM_BINS - 1] so values
     * slightly outside [0, 1] cannot index out of bounds. NaN maps to bin 0.
     */
    private static int binFor(double value) {
        int bin = (int) (value * 10.0);
        if (bin < 0) {
            return 0;
        }
        return bin >= NUM_BINS ? NUM_BINS - 1 : bin;
    }

    /** Formats the progress/summary line printed during and after the correlation loop. */
    private static String formatProgress(int counter, int below80, int belowSquared80, int flipped) {
        return counter + " SNPS processed\t" + below80 + "\t" + 100.0 * (double) below80 / (double) counter + "% R <= 0.80.\t" + 100.0 * (double) belowSquared80 / (double) counter + "% R2 <= 0.80.\t" + flipped + " - " + 100.0 * (double) flipped / (double) counter + " flipped";
    }

    /** Writes the observed (and, if loaded, Beagle) R2 histograms to {@link #log}. */
    private void writeDistributions(int[] correlationfreqdistribution) throws IOException {
        this.log.writeln("Distribution of correlations (counts, R2):");
        int cumulative = 0;
        for (int u = 0; u < correlationfreqdistribution.length; ++u) {
            cumulative += correlationfreqdistribution[u];
            this.log.writeln(u + "\t" + correlationfreqdistribution[u] + "\t" + cumulative);
        }
        this.log.writeln("Distribution of correlations (frequency, R2):");
        this.writeFrequencies(correlationfreqdistribution);
        if (this.beaglecorrelationfreqdistribution != null) {
            this.log.writeln("Beagle Distribution of correlations (R2):");
            this.writeFrequencies(this.beaglecorrelationfreqdistribution);
        }
    }

    /** Writes one "bin, relative frequency, cumulative frequency" line per histogram bin. */
    private void writeFrequencies(int[] distribution) throws IOException {
        int total = 0;
        for (int u = 0; u < distribution.length; ++u) {
            total += distribution[u];
        }
        double cumulative = 0.0;
        for (int u = 0; u < distribution.length; ++u) {
            double freq = (double) distribution[u] / (double) total;
            cumulative += freq;
            this.log.writeln(u + "\t" + freq + "\t" + cumulative);
        }
    }

    /**
     * Loads Beagle R2 values for chromosomes 1-22 of every batch.
     *
     * <p>If any expected file is unreadable, the missing files are reported and nothing is loaded
     * ({@link #beagleR2} stays null, so {@link #run(String)} produces no Beagle plot or histogram).
     * Each file line must contain exactly two tab-separated fields: SNP name and R2. Other lines
     * are skipped. NaN R2 values are stored as 0.0.
     */
    private void loadBeagleR2(String inputDir, String template, Integer numBatches) throws IOException {
        String[] batchNames = this.getBatches(numBatches);
        ArrayList<String> fileNames = new ArrayList<String>();
        boolean allFilesAvailable = true;
        for (String batchName : batchNames) {
            for (int chr = 1; chr <= 22; ++chr) {
                String fileName = beagleR2FileName(inputDir, template, batchName, chr);
                if (!Gpio.canRead(fileName)) {
                    System.out.println("Cannot open file:\t" + fileName);
                    allFilesAvailable = false;
                }
                fileNames.add(fileName);
            }
        }
        if (!allFilesAvailable) {
            return;
        }
        this.beagleR2 = new HashMap<String, ArrayList<Double>>();
        this.beaglecorrelationfreqdistribution = new int[NUM_BINS];
        for (String fileName : fileNames) {
            this.readBeagleR2File(fileName);
        }
    }

    /** Builds {@code inputDir/<template with BATCH and CHROMOSOME substituted>.r2}. */
    private static String beagleR2FileName(String inputDir, String template, String batchName, int chr) {
        return inputDir + "/" + template.replace("BATCH", batchName).replace("CHROMOSOME", String.valueOf(chr)) + ".r2";
    }

    /** Adds the R2 values of one Beagle file to {@link #beagleR2} and its histogram. */
    private void readBeagleR2File(String fileName) throws IOException {
        TextFile in = new TextFile(fileName, false);
        try {
            String line;
            while ((line = in.readLine()) != null) {
                String[] elems = TAB.split(line);
                if (elems.length != 2) {
                    continue;
                }
                try {
                    double r2 = Double.parseDouble(elems[1]);
                    if (Double.isNaN(r2)) {
                        r2 = 0.0;
                    }
                    ++this.beaglecorrelationfreqdistribution[binFor(r2)];
                    ArrayList<Double> r2s = this.beagleR2.get(elems[0]);
                    if (r2s == null) {
                        r2s = new ArrayList<Double>();
                        this.beagleR2.put(elems[0], r2s);
                    }
                    r2s.add(r2);
                } catch (NumberFormatException e) {
                    e.printStackTrace();
                }
            }
        } finally {
            in.close();
        }
    }

    /**
     * Writes {@code num} randomly chosen SNP names present in both datasets to {@code output},
     * without repetition. Writes fewer names when fewer shared SNPs exist. Not called within
     * this class.
     */
    private void getRandomSNPs(int num, String output) throws IOException {
        String[] snps1 = this.ggDataset1.getSNPs();
        ArrayList<String> snps = new ArrayList<String>();
        for (int i = 0; i < snps1.length; ++i) {
            if (isPresent(this.ggDataset2, snps1[i])) {
                snps.add(snps1[i]);
            }
        }
        TextFile log2 = new TextFile(output, true);
        try {
            int toDraw = Math.min(num, snps.size());
            for (int i = 0; i < toDraw; ++i) {
                // Math.random() is in [0, 1), so the index is always < snps.size().
                log2.writeln(snps.remove((int) (Math.random() * (double) snps.size())));
            }
        } finally {
            log2.close();
        }
    }

    /**
     * Logs how many SNPs are missing from the datasets.
     *
     * <p>With a confine list, reports how many listed SNPs are absent from each dataset;
     * otherwise reports how many SNPs of each dataset are absent from the other.
     * NOTE(review): the confine list is used only for this report; {@link #run(String)} does not
     * restrict the correlated SNPs to it — confirm whether that is intended.
     */
    private void determineUniqueSNPS() throws IOException {
        if (this.confineSNPList != null) {
            int absent1 = 0;
            int absent2 = 0;
            for (String snp : this.confineSNPList) {
                if (!isPresent(this.ggDataset1, snp)) {
                    ++absent1;
                }
                if (!isPresent(this.ggDataset2, snp)) {
                    ++absent2;
                }
            }
            this.log.writeln(absent1 + " of the " + this.confineSNPList.size() + " SNPs in ConfineSNPList are not present in " + this.ggDataset1.getGenotypeFileName() + ", and " + absent2 + " are not present in " + this.ggDataset2.getGenotypeFileName());
        } else {
            String[] snps1 = this.ggDataset1.getSNPs();
            String[] snps2 = this.ggDataset2.getSNPs();
            int counter = countAbsent(snps1, this.ggDataset2);
            this.log.writeln(counter + " of the " + snps1.length + " SNPs in " + this.ggDataset1.getGenotypeFileName() + " are not present in " + this.ggDataset2.getGenotypeFileName() + ": " + (double) counter / (double) snps1.length * 100.0 + "%");
            counter = countAbsent(snps2, this.ggDataset1);
            this.log.writeln(counter + " of the " + snps2.length + " SNPs in " + this.ggDataset2.getGenotypeFileName() + " are not present in " + this.ggDataset1.getGenotypeFileName() + ": " + (double) counter / (double) snps2.length * 100.0 + "%");
        }
    }

    /** Number of {@code snps} that {@code data} does not contain. */
    private static int countAbsent(String[] snps, TriTyperGenotypeData data) {
        int absent = 0;
        for (int i = 0; i < snps.length; ++i) {
            if (!isPresent(data, snps[i])) {
                ++absent;
            }
        }
        return absent;
    }

    /**
     * Whether {@code data} contains {@code snp}. The SNP-id lookup is treated as "absent" when it
     * yields null or a negative id (elsewhere in this class -9 was compared against as a sentinel).
     */
    private static boolean isPresent(TriTyperGenotypeData data, String snp) {
        Integer id = data.getSnpToSNPId().get(snp);
        return id != null && id.intValue() >= 0;
    }

    /**
     * Generates two-letter batch names aa, ab, ..., az, ba, ... (presumably matching the suffixes
     * of split batch files — confirm against the Beagle output naming).
     *
     * @throws IllegalArgumentException if {@code numBatches} is negative or exceeds 26 * 26
     */
    private String[] getBatches(int numBatches) {
        int maxBatches = ALPHABET.length * ALPHABET.length;
        if (numBatches < 0 || numBatches > maxBatches) {
            throw new IllegalArgumentException("numBatches must be between 0 and " + maxBatches + ": " + numBatches);
        }
        String[] batches = new String[numBatches];
        for (int i = 0; i < numBatches; ++i) {
            batches[i] = ALPHABET[i / ALPHABET.length] + ALPHABET[i % ALPHABET.length];
        }
        return batches;
    }

    /**
     * Reads a list of SNP names, one per line, into {@link #confineSNPList}. The previous list is
     * kept if reading fails.
     *
     * @param snpList path of the SNP list file
     * @throws IOException if the file cannot be read
     */
    public void confineToSNPs(String snpList) throws IOException {
        HashSet<String> snps = new HashSet<String>();
        TextFile tf = new TextFile(snpList, false);
        try {
            snps.addAll(tf.readAsArrayList());
        } finally {
            tf.close();
        }
        this.confineSNPList = snps;
    }

    /**
     * Maps dataset-1 sample indices to dataset-2 indices by exact sample-name match. When a name
     * occurs more than once in dataset 2, the last match wins. Not called within this class.
     *
     * @return the index map, or null when no sample names are shared
     */
    private HashMap<Integer, Integer> determineSampleIDMap() {
        String[] ds1Samples = this.ggDataset1.getIndividuals();
        String[] ds2Samples = this.ggDataset2.getIndividuals();
        HashMap<Integer, Integer> samplemap = new HashMap<Integer, Integer>();
        int numShared = 0;
        for (int i = 0; i < ds1Samples.length; ++i) {
            String sample1 = ds1Samples[i];
            for (int j = 0; j < ds2Samples.length; ++j) {
                if (!sample1.equals(ds2Samples[j])) {
                    continue;
                }
                samplemap.put(i, j);
                ++numShared;
            }
        }
        System.out.println("Shared samples between datasets:" + numShared);
        return numShared > 0 ? samplemap : null;
    }
}

