/*
 * Decompiled with CFR 0.152.
 */
package jebl.evolution.trees;

import java.io.BufferedReader;
import java.io.File;
import java.io.FileNotFoundException;
import java.io.FileReader;
import java.io.IOException;
import java.io.PrintWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import jebl.evolution.graphs.Node;
import jebl.evolution.io.ImportException;
import jebl.evolution.io.NexusImporter;
import jebl.evolution.trees.RootedTree;
import jebl.evolution.trees.RootedTreeUtils;

/*
 * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
 */
public class CalculateSplitRates {
    List<RootedTree> treeList;
    NexusImporter importer;
    public final String INDICATOR = "changed";
    public final String RATE = "rate";
    private List<Clade> cladeList;
    private List<List<TimeInterval>> intervalList;
    private DensityMap densityMap;
    public int numRateBoxes = 25;
    public int numTimeBoxes = 100;
    public double edgeFraction = 0.05;

    public CalculateSplitRates(NexusImporter importer) {
        this.importer = importer;
        this.treeList = new ArrayList<RootedTree>(100);
        this.cladeList = new ArrayList<Clade>(100);
        this.intervalList = new ArrayList<List<TimeInterval>>(100);
        this.densityMap = new DensityMap(70, 20, 0.0, 0.0, 70.0, 0.01);
    }

    public void loadTrees(int maxTrees, int burnIn) throws IOException, ImportException {
        int cnt = 0;
        int cntBurnin = 0;
        while (this.importer.hasTree() && cnt < maxTrees) {
            RootedTree tree = (RootedTree)this.importer.importNextTree();
            if (cntBurnin > burnIn) {
                this.treeList.add(tree);
                ++cnt;
            }
            ++cntBurnin;
        }
    }

    private DensityMap createDensityMap(int numRateBoxes, int numTimeBoxes) {
        double maxTreeHeight = 0.0;
        double minRate = 1.0;
        double maxRate = 0.0;
        for (RootedTree tree : this.treeList) {
            double thisHeight = tree.getHeight(tree.getRootNode());
            if (thisHeight > maxTreeHeight) {
                maxTreeHeight = thisHeight;
            }
            Set<Node> nodeList = tree.getNodes();
            for (Node node : nodeList) {
                if (node == tree.getRootNode()) continue;
                double rate = this.getRate(node);
                if (rate < minRate) {
                    minRate = rate;
                }
                if (!(rate > maxRate)) continue;
                maxRate = rate;
            }
        }
        System.err.println("maxTreeHeight = " + maxTreeHeight);
        System.err.println("minRate = " + minRate);
        System.err.println("maxRate = " + maxRate);
        maxTreeHeight *= 1.0 + this.edgeFraction;
        double rateSpread = maxRate - minRate;
        minRate -= rateSpread * this.edgeFraction;
        if (minRate < 0.0) {
            minRate = 0.0;
        }
        System.out.println("real max = " + maxRate);
        System.out.println("new  max = " + (maxRate += rateSpread * this.edgeFraction));
        DensityMap densityMap = new DensityMap(numTimeBoxes, numRateBoxes, 0.0, minRate, maxTreeHeight, maxRate);
        for (RootedTree tree : this.treeList) {
            this.addTreeToDensityMap(densityMap, tree);
        }
        return densityMap;
    }

    private void displayLongestDwellTimeInfo() {
        for (RootedTree tree : this.treeList) {
            double longestDwell = this.getLongestClockDwellTime(this.getClockDwellTimes(tree));
            double treeLenght = this.getTreeLength(tree);
            System.out.printf("%5.4f\n", longestDwell / treeLenght);
        }
    }

    public void displayStatistics() {
    }

    public void addTreeToDensityMap(DensityMap densityMap, RootedTree tree) {
        Set<Node> nodeList = tree.getNodes();
        for (Node node : nodeList) {
            if (node == tree.getRootNode()) continue;
            densityMap.addTreeBranch(tree.getHeight(node), tree.getHeight(tree.getParent(node)), this.getRate(node));
        }
    }

    public void displayIntervals() {
        if (this.intervalList.size() == 0) {
            return;
        }
        System.out.println("Interval counts:");
        for (List<TimeInterval> timesList : this.intervalList) {
            System.out.println("\t" + timesList.size() + " " + this.getLongestInterval(timesList));
        }
    }

    public void findTimeIntervals() {
    }

    private double getRate(Node node) {
        Double rateDouble = (Double)node.getAttribute("rate");
        return rateDouble;
    }

    private boolean rateChanged(Node node) {
        Integer changedInt = (Integer)node.getAttribute("changed");
        return changedInt == 1;
    }

    private List<TimeInterval> getTimeIntervals(RootedTree tree) {
        return this.getTimeIntervals(tree, tree.getRootNode(), tree.getHeight(tree.getRootNode()), new ArrayList<TimeInterval>());
    }

    private double getLongestClockDwellTime(Map<Double, Double> dwellTimes) {
        double time = 0.0;
        Set<Double> rates = dwellTimes.keySet();
        for (Double rate : rates) {
            Double dwell = dwellTimes.get(rate);
            if (!(dwell > time)) continue;
            time = dwell;
        }
        return time;
    }

    private Map<Double, Double> getClockDwellTimes(RootedTree tree) {
        LinkedHashMap<Double, Double> rateDwellTimes = new LinkedHashMap<Double, Double>();
        Set<Node> nodes = tree.getNodes();
        for (Node node : nodes) {
            if (node == tree.getRootNode()) continue;
            double branchLength = tree.getLength(node);
            double rate = this.getRate(node);
            Double thisRate = new Double(rate);
            Double dwellTime = (Double)rateDwellTimes.get(thisRate);
            if (dwellTime == null) {
                rateDwellTimes.put(thisRate, new Double(branchLength));
                continue;
            }
            rateDwellTimes.put(thisRate, new Double(dwellTime + branchLength));
        }
        return rateDwellTimes;
    }

    private double getTreeLength(RootedTree tree) {
        double total = 0.0;
        Set<Node> nodes = tree.getNodes();
        for (Node node : nodes) {
            total += tree.getLength(node);
        }
        return total;
    }

    private List<TimeInterval> getTimeIntervals(RootedTree tree, Node node, double startTime, List<TimeInterval> intervals) {
        if (tree.isExternal(node)) {
            TimeInterval timeInterval = new TimeInterval(startTime, tree.getHeight(node), this.getRate(node));
            intervals.add(timeInterval);
            return null;
        }
        List<Node> children = tree.getChildren(node);
        for (Node child : children) {
            if (this.rateChanged(child)) {
                TimeInterval timeInterval = new TimeInterval(startTime, tree.getHeight(node), this.getRate(node));
                intervals.add(timeInterval);
                this.getTimeIntervals(tree, child, tree.getHeight(node), intervals);
                continue;
            }
            this.getTimeIntervals(tree, child, startTime, intervals);
        }
        return intervals;
    }

    private void addCladeRateInforamtion(RootedTree tree) {
        for (Node node : tree.getInternalNodes()) {
            this.addCladeRateInformation(tree, node);
        }
        for (Node node : tree.getExternalNodes()) {
            this.addCladeRateInformation(tree, node);
        }
    }

    private double getLongestInterval(List<TimeInterval> intervals) {
        Collections.sort(intervals);
        return intervals.get(intervals.size() - 1).getLength();
    }

    private void addCladeRateInformation(RootedTree tree, Node node) {
        if (tree.getRootNode() != node) {
            Integer changedInt = (Integer)node.getAttribute("changed");
            Double rateDouble = (Double)node.getAttribute("rate");
            String name = this.constructUniqueName(tree, node);
            Clade newClade = new Clade(name);
            int index = this.cladeList.indexOf(newClade);
            if (index == -1) {
                index = this.cladeList.size();
                this.cladeList.add(newClade);
            }
            this.cladeList.get(index).addValues(changedInt, rateDouble);
        }
    }

    private String constructUniqueName(RootedTree tree, Node node) {
        if (tree.isExternal(node)) {
            return tree.getTaxon(node).getName();
        }
        Set<Node> taxa = RootedTreeUtils.getDescendantTips(tree, node);
        ArrayList<String> nameList = new ArrayList<String>(taxa.size());
        for (Node tip : taxa) {
            nameList.add(tree.getTaxon(tip).getName());
        }
        Collections.sort(nameList);
        StringBuffer sb = new StringBuffer();
        int cnt = 0;
        for (String name : nameList) {
            if (cnt != 0) {
                sb.append(",");
            }
            sb.append(name);
            ++cnt;
        }
        return sb.toString();
    }

    public static void main(String[] args) {
        try {
            NexusImporter importer = new NexusImporter(new BufferedReader(new FileReader(new File(args[0]))));
            CalculateSplitRates calculator = new CalculateSplitRates(importer);
            calculator.loadTrees(Integer.parseInt(args[1]), Integer.parseInt(args[2]));
            PrintWriter printWriter = new PrintWriter(args[3]);
            calculator.writeLongestDwellTimeInfo(printWriter);
            printWriter.close();
            printWriter = new PrintWriter(args[4]);
            calculator.writeDensityMap(printWriter);
            printWriter.close();
        }
        catch (FileNotFoundException e) {
            e.printStackTrace();
        }
        catch (ImportException e) {
            e.printStackTrace();
        }
        catch (IOException e) {
            e.printStackTrace();
        }
        System.exit(0);
    }

    private void writeDensityMap(PrintWriter printWriter) {
        this.densityMap = this.createDensityMap(this.numRateBoxes, this.numTimeBoxes);
        printWriter.println(this.densityMap.toString());
    }

    private void writeLongestDwellTimeInfo(PrintWriter printWriter) {
        printWriter.print("DwellTime\tTreeLength\tProportion\n");
        for (RootedTree tree : this.treeList) {
            Map<Double, Double> map = this.getClockDwellTimes(tree);
            double longestDwell = this.getLongestClockDwellTime(map);
            double treeLenght = this.getTreeLength(tree);
            double proportion = longestDwell / treeLenght;
            printWriter.printf("%5.4f\t%5.4f\t%5.4f\n", longestDwell, treeLenght, proportion);
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class Clade
    implements Comparable<Clade> {
        private String name;
        private List<Integer> indicatorValues;
        private List<Double> rateValues;
        private int count = 0;

        public Clade(String name) {
            this.name = name;
            this.indicatorValues = new ArrayList<Integer>(1000);
            this.rateValues = new ArrayList<Double>(1000);
        }

        @Override
        public int compareTo(Clade clade) {
            System.out.println("co");
            return this.name.compareTo(clade.getName());
        }

        public void addValues(Integer inInt, Double inDouble) {
            this.indicatorValues.add(inInt);
            this.rateValues.add(inDouble);
            ++this.count;
        }

        public boolean equals(Object obj) {
            Clade c = (Clade)obj;
            return this.name.compareTo(c.getName()) == 0;
        }

        public String getName() {
            return this.name;
        }

        public int getCount() {
            return this.count;
        }

        public double getIndicatorProbability() {
            int sum = 0;
            for (Integer i : this.indicatorValues) {
                sum += i.intValue();
            }
            return (double)sum / (double)this.count;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class CladeFrequencyComparator
    implements Comparator<Clade> {
        private CladeFrequencyComparator() {
        }

        @Override
        public int compare(Clade cladeA, Clade cladeB) {
            if (cladeA.getCount() > cladeB.getCount()) {
                return -1;
            }
            if (cladeA.getCount() < cladeB.getCount()) {
                return 1;
            }
            return cladeA.getName().compareTo(cladeB.getName());
        }
    }

    private class DensityMap {
        private final String SEP = "\t";
        private final String DBL = "%5.4f";
        private int binX;
        private int binY;
        private int[][] data;
        private int[] counts;
        private double startX;
        private double startY;
        private double scaleX;
        private double scaleY;
        private int total = 0;

        public DensityMap(int binX, int binY, double startX, double startY, double endX, double endY) {
            this.binX = binX;
            this.binY = binY;
            this.data = new int[binX][binY];
            this.counts = new int[binX];
            this.startX = startX;
            this.startY = startY;
            this.scaleX = (endX - startX) / (double)binX;
            this.scaleY = (endY - startY) / (double)binY;
        }

        public void addTreeBranch(double start, double end, double y) {
            int Y = (int)((y - this.startY) / this.scaleY);
            int START = (int)((start - this.startX) / this.scaleX);
            int END = (int)((end - this.startX) / this.scaleX);
            int i = START;
            while (i <= END) {
                int[] nArray = this.data[i];
                int n = Y;
                nArray[n] = nArray[n] + 1;
                int n2 = i++;
                this.counts[n2] = this.counts[n2] + 1;
                ++this.total;
            }
        }

        public String toString() {
            StringBuilder sb = new StringBuilder();
            sb.append("0.0");
            int i = 0;
            while (i < this.binX) {
                sb.append("\t");
                sb.append(String.format("%7.5f", this.startX + this.scaleX * (double)i));
                ++i;
            }
            sb.append("\n");
            i = 0;
            while (i < this.binY) {
                sb.append(String.format("%7.5f", this.startY + this.scaleY * (double)i));
                int j = 0;
                while (j < this.binX) {
                    sb.append("\t");
                    double dblCounts = this.counts[j];
                    if (dblCounts > 0.0) {
                        sb.append(String.format("%5.4f", (double)this.data[j][i] / (double)this.counts[j]));
                    } else {
                        sb.append(String.format("%5.4f", 0.0));
                    }
                    ++j;
                }
                sb.append("\n");
                ++i;
            }
            return sb.toString();
        }
    }

    private class DoubleStatistic {
        private List<Double> data = new ArrayList<Double>(1000);
        private double total = 0.0;

        public void add(double d) {
            this.data.add(d);
            this.total += d;
        }

        public double getMean() {
            return this.total / (double)this.data.size();
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class DwellTime
    implements Comparable<DwellTime> {
        private double start;
        private double rate;
        private double length;

        public DwellTime(double start, double length, double rate) {
            this.start = start;
            this.length = length;
            this.rate = rate;
        }

        @Override
        public int compareTo(DwellTime dwellTime) {
            return (int)(this.getLength() - dwellTime.getLength());
        }

        public double getLength() {
            return this.length;
        }
    }

    /*
     * This class specifies class file version 49.0 but uses Java 6 signatures.  Assumed Java 6.
     */
    private class TimeInterval
    implements Comparable<TimeInterval> {
        private double start;
        private double end;
        private double rate;

        public TimeInterval(double start, double end, double rate) {
            this.start = start;
            this.end = end;
            this.rate = rate;
        }

        @Override
        public int compareTo(TimeInterval timeInterval) {
            return (int)(this.getLength() - timeInterval.getLength());
        }

        public double getLength() {
            return this.start - this.end;
        }

        public double getStart() {
            return this.start;
        }

        public double getEnd() {
            return this.end;
        }

        public double getRate() {
            return this.rate;
        }
    }
}

