/*
 * Decompiled with CFR 0.152.
 */
package ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.CPT.CPT;

import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.database.Item;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.database.Sequence;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.helpers.MemoryLogger;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.CPT.CPT.Bitvector;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.CPT.CPT.CPTHelper;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.CPT.CPT.PredictionTree;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.Paramable;
import ca.pfv.spmf.algorithms.sequenceprediction.ipredict.predictor.Predictor;
import java.util.ArrayList;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Iterator;
import java.util.List;
import java.util.Map;

/**
 * Sequence predictor built on a prediction tree plus two auxiliary indexes.
 *
 * <p>Training ({@link #Train(List)}) inserts every (optionally sliced) training
 * sequence into a prefix tree rooted at {@code Root}, records for each item
 * value the set of sequence ids containing it ({@code II}), and remembers the
 * last tree node of every sequence ({@code LT}). Prediction
 * ({@link #Predict(Sequence)}) looks up the training sequences that contain the
 * items of the target (or of sub-sequences of it), credits the items that
 * follow the target items in those sequences, and returns the item with the
 * highest accumulated score.
 *
 * <p>Instances keep mutable state ({@code CountTable} is replaced on every
 * prediction); nothing in this class synchronizes access to it.
 */
public class CPTPredictor extends Predictor {
    /** Root of the prediction tree; replaced on every call to {@link #Train(List)}. */
    private PredictionTree Root = new PredictionTree();
    /** Lookup table: sequence id -> last node of that sequence's path in the tree. */
    private Map<Integer, PredictionTree> LT = new HashMap<Integer, PredictionTree>();
    /** Inverted index: item value -> bit vector of the sequence ids containing that item. */
    private Map<Integer, Bitvector> II = new HashMap<Integer, Bitvector>();
    /** Scores produced by the most recent {@link #Predict(Sequence)} call; {@code null} before the first one. */
    private Map<Integer, Float> CountTable;
    /** Identifier reported by {@link #getTAG()}. */
    private String TAG = "CPT";
    /** Number of nodes added to the prediction tree during the last training run. */
    private long nodeNumber = 0L;
    /** Tunable parameters read by {@link #Train(List)} and {@link #Predict(Sequence)}. */
    public Paramable parameters = new Paramable();

    /** Creates a predictor with the default tag and default parameters. */
    public CPTPredictor() {
    }

    /**
     * Creates a predictor with a custom tag.
     *
     * @param tag value returned by {@link #getTAG()}
     */
    public CPTPredictor(String tag) {
        this();
        this.TAG = tag;
    }

    /**
     * Creates a predictor with a custom tag and parameters.
     *
     * @param tag    value returned by {@link #getTAG()}
     * @param params parameter description handed to {@code Paramable.setParameter}
     */
    public CPTPredictor(String tag, String params) {
        this(tag);
        this.parameters.setParameter(params);
    }

    /**
     * Finds the training sequences that contain every indexed item of the target.
     *
     * <p>Items with no entry in the inverted index are skipped when the target
     * has more than one element. For a single-element target the bit vector
     * stored in the index is returned directly (not a copy), so callers must
     * not modify the result.
     *
     * @param targetArray items to look up
     * @return bit vector of matching sequence ids; a fresh empty vector when nothing matches
     */
    private Bitvector getMatchingSequences(Item[] targetArray) {
        Bitvector intersection = null;
        if (targetArray.length == 1) {
            intersection = this.II.get(targetArray[0].val);
        } else {
            for (Item item : targetArray) {
                Bitvector occurrences = this.II.get(item.val);
                if (occurrences == null) {
                    continue;
                }
                if (intersection == null) {
                    // Clone so that and() below never alters the index itself.
                    intersection = (Bitvector) occurrences.clone();
                } else {
                    intersection.and(occurrences);
                }
            }
        }
        if (intersection == null || intersection.cardinality() == 0) {
            return new Bitvector();
        }
        return intersection;
    }

    /**
     * Credits, in the count table, the items that follow the target items in
     * every matching training sequence that has not been visited yet.
     *
     * <p>Each credited item receives {@code weight / (number of matching sequences)}.
     * A sequence id is added to {@code hashSidVisited} only when at least one
     * of its items was credited.
     *
     * @param targetArray    items that must all occur in a sequence for it to match
     * @param weight         multiplier applied to every credit
     * @param CountTable2    item value -> accumulated score; updated in place
     * @param hashSidVisited sequence ids already credited; updated in place
     */
    private void UpdateCountTable(Item[] targetArray, float weight, Map<Integer, Float> CountTable2, HashSet<Integer> hashSidVisited) {
        Bitvector indexes = this.getMatchingSequences(targetArray);
        if (indexes.cardinality() == 0) {
            return;
        }
        HashSet<Integer> hashTarget = new HashSet<Integer>(targetArray.length);
        for (Item it : targetArray) {
            hashTarget.add(it.val);
        }
        // The credit is the same for every item of every matching sequence.
        float increment = 1.0f / (float) indexes.cardinality() * weight;

        for (int sid = indexes.nextSetBit(0); sid >= 0; sid = indexes.nextSetBit(sid + 1)) {
            if (hashSidVisited.contains(sid)) {
                continue;
            }
            // Walk from the sequence's last node up to the root, so the branch
            // is stored in reverse order: index 0 is the last item of the sequence.
            List<Item> branch = new ArrayList<Item>();
            PredictionTree curNode = this.LT.get(sid);
            branch.add(curNode.Item);
            while (curNode.Parent != null) {
                curNode = curNode.Parent;
                branch.add(curNode.Item);
            }

            // Scan from the start of the sequence (end of the list) until every
            // distinct target item has been met; what remains is the consequent.
            HashSet<Integer> alreadySeen = new HashSet<Integer>();
            int i;
            for (i = branch.size() - 1; i >= 0 && alreadySeen.size() != hashTarget.size(); --i) {
                if (hashTarget.contains(branch.get(i).val)) {
                    alreadySeen.add(branch.get(i).val);
                }
            }
            int consequentEndPosition = i;
            if (consequentEndPosition < 0) {
                // Nothing follows the target items in this sequence.
                continue;
            }
            for (i = 0; i <= consequentEndPosition; ++i) {
                Integer key = branch.get(i).val;
                Float previous = CountTable2.get(key);
                float oldValue = previous == null ? 0.0f : previous.floatValue();
                CountTable2.put(key, Float.valueOf(oldValue + increment));
            }
            hashSidVisited.add(sid);
        }
    }

    /**
     * Picks the predicted item from a count table.
     *
     * <p>The item with the strictly highest score wins. When several items
     * share the highest score, the one with the highest lift (score divided by
     * the number of training sequences containing the item) is chosen; if no
     * tied item can be ranked that way, the first item found with the highest
     * score is used.
     *
     * @param CountTable2 item value -> accumulated score
     * @return a sequence holding the single predicted item, or an empty
     *         sequence when the table yields no candidate
     */
    private Sequence getBestSequenceFromCountTable(Map<Integer, Float> CountTable2) {
        double maxValue = -1.0;
        double secondMaxValue = -1.0;
        Integer maxItem = -1;
        for (Map.Entry<Integer, Float> entry : CountTable2.entrySet()) {
            double score = entry.getValue().floatValue();
            if (score > maxValue) {
                secondMaxValue = maxValue;
                maxItem = entry.getKey();
                maxValue = score;
            } else if (score > secondMaxValue) {
                secondMaxValue = score;
            }
        }

        Sequence predicted = new Sequence(-1);
        if (maxItem == -1) {
            return predicted;
        }

        // Relative gap between the best and the second best score; 0 means a tie.
        double diff = 1.0 - secondMaxValue / maxValue;
        if (secondMaxValue == -1.0 || diff > 0.0) {
            // Single candidate, or a clear winner.
            predicted.addItem(new Item(maxItem));
        } else if (diff == 0.0) {
            // Several items share the best score: prefer the highest lift.
            double highestLift = 0.0;
            Integer bestItem = maxItem;
            for (Map.Entry<Integer, Float> entry : CountTable2.entrySet()) {
                if (maxValue != (double) entry.getValue().floatValue()) {
                    continue;
                }
                Bitvector occurrences = this.II.get(entry.getKey());
                if (occurrences == null) {
                    continue;
                }
                double lift = (double) (entry.getValue().floatValue() / (float) occurrences.cardinality());
                if (lift > highestLift) {
                    highestLift = lift;
                    bestItem = entry.getKey();
                }
            }
            predicted.addItem(new Item(bestItem));
        }
        return predicted;
    }

    /**
     * Predicts the item most likely to follow the given sequence.
     *
     * <p>NOTE: items that were never seen during training are removed from
     * {@code target} itself, so the caller's sequence is modified.
     *
     * <p>A single-item target is scored directly. Longer targets are scored
     * through {@link #RecursiveDivider}, once per level between the
     * {@code recursiveDividerMin} and {@code recursiveDividerMax} parameters
     * (the latter capped at the target length); all levels accumulate into the
     * same count table, which stays available through {@link #getCountTable()}.
     *
     * @param target sequence to extend; filtered in place
     * @return a sequence with the predicted item, or an empty sequence if none
     */
    @Override
    public Sequence Predict(Sequence target) {
        Iterator<Item> iter = target.getItems().iterator();
        while (iter.hasNext()) {
            if (this.II.get(iter.next().val) == null) {
                iter.remove();
            }
        }

        Item[] targetArray = new Item[target.size()];
        for (int i = 0; i < target.getItems().size(); ++i) {
            targetArray[i] = target.get(i);
        }
        int initialTargetArraySize = targetArray.length;

        int minRecursion = this.parameters.paramInt("recursiveDividerMin");
        int maxRecursion = Math.min(this.parameters.paramInt("recursiveDividerMax"), targetArray.length);

        HashSet<Integer> hashSidVisited = new HashSet<Integer>();
        this.CountTable = new HashMap<Integer, Float>();
        if (initialTargetArraySize == 1) {
            // The whole (single-item) target is used, hence a weight of 1.
            this.UpdateCountTable(targetArray, 1.0f, this.CountTable, hashSidVisited);
        } else {
            // The prediction is only extracted after the loop, so every level runs.
            for (int level = minRecursion; level < maxRecursion; ++level) {
                int minSize = targetArray.length - level;
                this.RecursiveDivider(targetArray, minSize, this.CountTable, hashSidVisited, initialTargetArraySize);
            }
        }
        return this.getBestSequenceFromCountTable(this.CountTable);
    }

    /**
     * Returns the count table built by the last {@link #Predict(Sequence)} call.
     *
     * @return the internal (mutable) map of item value to score, or
     *         {@code null} if no prediction has been made yet
     */
    public Map<Integer, Float> getCountTable() {
        return this.CountTable;
    }

    /**
     * Scores the given items and, recursively, every sub-sequence obtained by
     * hiding one item at a time, stopping at sub-sequences of {@code minSize}
     * items or fewer.
     *
     * <p>Each scored sub-sequence is weighted by its length divided by
     * {@code initialTargetArraySize}.
     *
     * @param targetArray            items to score
     * @param minSize                sub-sequences of this length or shorter are not scored
     * @param countTable             item value -> accumulated score; updated in place
     * @param hashSidVisited         sequence ids already credited; updated in place
     * @param initialTargetArraySize length of the original target, used for weighting
     */
    public void RecursiveDivider(Item[] targetArray, int minSize, Map<Integer, Float> countTable, HashSet<Integer> hashSidVisited, int initialTargetArraySize) {
        int size = targetArray.length;
        if (size <= minSize) {
            return;
        }
        float weight = (float) size / (float) initialTargetArraySize;
        this.UpdateCountTable(targetArray, weight, countTable, hashSidVisited);

        for (int toHide = 0; toHide < size; ++toHide) {
            // Copy everything except the element at position toHide.
            Item[] newSequence = new Item[size - 1];
            System.arraycopy(targetArray, 0, newSequence, 0, toHide);
            System.arraycopy(targetArray, toHide + 1, newSequence, toHide, size - toHide - 1);
            this.RecursiveDivider(newSequence, minSize, countTable, hashSidVisited, initialTargetArraySize);
        }
    }

    /**
     * @return the tag given at construction time ("CPT" by default)
     */
    @Override
    public String getTAG() {
        return this.TAG;
    }

    /**
     * Rebuilds the prediction tree, the lookup table and the inverted index
     * from scratch.
     *
     * <p>A sequence longer than the {@code splitLength} parameter is replaced
     * by the output of {@code CPTHelper.sliceBasic} when {@code splitMethod}
     * is 1, or of {@code CPTHelper.slice} for any other positive
     * {@code splitMethod}; otherwise it is used unchanged.
     *
     * @param trainingSequences sequences to learn from
     * @return always {@code true}
     */
    @Override
    public Boolean Train(List<Sequence> trainingSequences) {
        this.nodeNumber = 0L;
        this.Root = new PredictionTree();
        this.LT = new HashMap<Integer, PredictionTree>();
        this.II = new HashMap<Integer, Bitvector>();
        MemoryLogger.addUpdate();

        // Step 1: slice the sequences that are too long.
        List<Sequence> newTrainingSet = new ArrayList<Sequence>();
        for (Sequence seq : trainingSequences) {
            if (seq.size() > this.parameters.paramInt("splitLength") && this.parameters.paramInt("splitMethod") > 0) {
                if (this.parameters.paramInt("splitMethod") == 1) {
                    newTrainingSet.addAll(CPTHelper.sliceBasic(seq, this.parameters.paramInt("splitLength")));
                } else {
                    newTrainingSet.addAll(CPTHelper.slice(seq, this.parameters.paramInt("splitLength")));
                }
            } else {
                newTrainingSet.add(seq);
            }
        }

        // Step 2: insert every sequence into the tree and the two indexes.
        int seqId = 0;
        for (Sequence curSeq : newTrainingSet) {
            PredictionTree curNode = this.Root;
            for (Item it : curSeq.getItems()) {
                if (!this.II.containsKey(it.val)) {
                    this.II.put(it.val, new Bitvector());
                }
                this.II.get(it.val).setBitAndIncrementCardinality(seqId);
                if (!curNode.hasChild(it).booleanValue()) {
                    curNode.addChild(it);
                    ++this.nodeNumber;
                }
                curNode = curNode.getChild(it);
            }
            this.LT.put(seqId, curNode);
            ++seqId;
        }

        // Step 3: drop items below the minimum support. With the threshold
        // fixed at 0 this currently removes nothing.
        int minSup = 0;
        Iterator<Map.Entry<Integer, Bitvector>> it = this.II.entrySet().iterator();
        while (it.hasNext()) {
            if (it.next().getValue().cardinality() < minSup) {
                it.remove();
            }
        }
        MemoryLogger.addUpdate();
        return true;
    }

    /**
     * @return the number of nodes added to the prediction tree by the last training run
     */
    @Override
    public long size() {
        return this.nodeNumber;
    }

    /**
     * Estimates the memory footprint of the trained model.
     *
     * <p>The estimate counts 3 fields of 4 units per tree node, one bit per
     * training sequence (rounded up to whole groups of 8) plus 4 units per
     * inverted-index entry, and 2 fields of 4 units per lookup-table entry.
     * The unit is presumably bytes; confirm against callers.
     *
     * @return the estimated size
     */
    @Override
    public float memoryUsage() {
        float sizePredictionTree = this.nodeNumber * 3L * 4L;
        // Divide by 8.0 so the value is rounded up; integer division would
        // truncate before Math.ceil could have any effect.
        float sizeInvertedIndex = (float) ((double) this.II.size() * (Math.ceil(this.LT.size() / 8.0) + 4.0));
        float sizeLookupTable = this.LT.size() * 2 * 4;
        return sizePredictionTree + sizeInvertedIndex + sizeLookupTable;
    }
}

