package gospl.algo.sr.bn;

import core.util.random.GenstarRandom;
import gospl.distribution.matrix.CachedNDimensionalMatrix;
import gospl.io.ipums.ReadIPUMSDictionaryUtils;
import java.io.File;
import java.io.FileWriter;
import java.io.IOException;
import java.lang.reflect.InvocationTargetException;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.commons.collections4.map.LRUMap;
import org.apache.commons.lang3.StringUtils;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

/**
 * A node of a decomposition tree (dtree) built over a {@link CategoricalBayesianNetwork}.
 * <p>
 * Leaves wrap the factor ({@link #f}) of a single network variable ({@link #n}); internal nodes
 * carry no factor and have exactly two children (see {@link #checkConsistency()}). The tree is
 * used to compute probabilities of evidence by recursive conditioning
 * ({@link #recursiveConditionning(Map)}) and to sample variable values ({@link #generate(Map)}).
 * <p>
 * Derived variable sets (union, intersection, cutset, acutset, context, cluster) are computed
 * lazily and memoized in the public fields below; every structural change invalidates them through
 * the {@code resetCache*} methods. No synchronization is performed on any of this state.
 */
public final class DNode {

    /** Fraction of the possible evidence combinations a node's result cache is sized for. */
    public static double cacheRatio = 0.5d;
    /** Minimum capacity of a node's result cache (unless fewer combinations exist). */
    public static int minCachedCount = 500;
    /** Maximum capacity of a node's result cache. */
    public static int maxCachedCount = CachedNDimensionalMatrix.MAX_SIZE;

    private static final Logger logger = LogManager.getLogger();

    /** Factor held by a leaf; {@code null} for internal nodes. */
    public final Factor f;
    /** Network variable whose factor is held by a leaf; {@code null} for internal nodes. */
    public final NodeCategorical n;
    public CategoricalBayesianNetwork bn;
    protected DNode left;
    protected DNode right;
    protected DNode parent;

    // Memoized derived sets; null means "not computed yet" (cleared by resetCache()).
    public Set<NodeCategorical> varsUnion = null;
    public Set<NodeCategorical> varsInter = null;
    public Set<NodeCategorical> cutset = null;
    public Set<NodeCategorical> acutset = null;
    public Set<NodeCategorical> context = null;
    public Set<NodeCategorical> cluster = null;

    /** Per-node memo of recursive-conditioning results, keyed by evidence restricted to varsUnion(). */
    private LRUMap<Map<NodeCategorical, String>, Double> cacheEvidenceInContext2proba = null;

    /** Creates an empty node with no factor, no variable and no network. */
    public DNode() {
        this.n = null;
        this.f = null;
        this.bn = null;
    }

    /**
     * Deep-copies a subtree: the factor and both children are cloned, the network is shared.
     * <p>
     * The clone starts with empty memoized sets and its OWN (lazily created) result cache: sharing
     * the source's cache would let a later reduction of the clone pollute the original's results,
     * and resetting the clone would wipe the original's cache.
     *
     * @param dNode node to copy
     */
    protected DNode(DNode dNode) {
        this.n = dNode.n;
        this.f = dNode.f == null ? null : dNode.f.m17clone();
        this.bn = dNode.bn;
        this.parent = null;
        if (dNode.right != null) {
            this.right = dNode.right.m14clone();
            this.right.parent = this;
        }
        if (dNode.left != null) {
            this.left = dNode.left.m14clone();
            this.left.parent = this;
        }
        resetCacheAll();
    }

    /**
     * Leaf constructor: wraps the factor of the given network variable.
     *
     * @param nodeCategorical variable whose factor this leaf holds
     */
    public DNode(NodeCategorical nodeCategorical) {
        this.n = nodeCategorical;
        this.f = nodeCategorical.asFactor();
        this.bn = nodeCategorical.cNetwork;
    }

    /**
     * Internal-node constructor: no factor; children are attached later.
     *
     * @param categoricalBayesianNetwork network the tree is built over
     */
    public DNode(CategoricalBayesianNetwork categoricalBayesianNetwork) {
        this.n = null;
        this.f = null;
        this.bn = categoricalBayesianNetwork;
    }

    /** Clears the memoized state of every ancestor (their derived sets depend on this subtree). */
    protected void resetCacheParents() {
        if (this.parent != null) {
            this.parent.resetCache();
            this.parent.resetCacheParents();
        }
    }

    /** Clears the memoized state of every descendant (their acutset/context depend on ancestors). */
    public void resetCacheChildren() {
        if (this.left != null) {
            this.left.resetCache();
            this.left.resetCacheChildren();
        }
        if (this.right != null) {
            this.right.resetCache();
            this.right.resetCacheChildren();
        }
    }

    /** Clears this node's result cache (keeping its capacity) and forgets its derived sets. */
    public void resetCache() {
        if (this.cacheEvidenceInContext2proba != null) {
            this.cacheEvidenceInContext2proba.clear();
        }
        this.cluster = null;
        this.context = null;
        this.acutset = null;
        this.cutset = null;
        this.varsUnion = null;
        this.varsInter = null;
    }

    /** Clears memoized state on this node, all its descendants and all its ancestors. */
    protected void resetCacheAll() {
        resetCache();
        resetCacheChildren();
        resetCacheParents();
    }

    // Currently unused by this class.
    private void computeVarsUnion() {
        varsUnion();
    }

    // Currently unused by this class.
    private void compute() {
        computeVarsUnion();
    }

    /**
     * @return number of edges between this node and the root (0 for the root)
     */
    public int getDepth() {
        return this.parent == null ? 0 : this.parent.getDepth() + 1;
    }

    /**
     * Deep copy of this subtree (JADX-renamed {@code clone}).
     *
     * @return a detached copy of this node and its descendants
     */
    public DNode m14clone() {
        return new DNode(this);
    }

    /**
     * Returns the result cache, creating it on first use.
     * <p>
     * Capacity is {@link #cacheRatio} times the number of value combinations of
     * {@link #varsUnion()}, raised to {@link #minCachedCount} (but never above the combination count),
     * capped at {@link #maxCachedCount}, and at least 1. When the combination count does not fit in a
     * {@code long}, the maximum capacity is used.
     */
    private LRUMap<Map<NodeCategorical, String>, Double> getCache() {
        if (this.cacheEvidenceInContext2proba == null) {
            long combinations = 1L;
            for (NodeCategorical variable : varsUnion()) {
                try {
                    combinations = Math.multiplyExact(combinations, (long) variable.getDomainSize());
                } catch (ArithmeticException overflow) {
                    combinations = -1L; // sentinel: too many combinations to count
                    break;
                }
            }
            // Stay in long arithmetic until clamped so the int cast below cannot wrap.
            long capacity = Math.round(cacheRatio * combinations);
            if (capacity < minCachedCount && combinations > 0) {
                capacity = Math.min(combinations, minCachedCount);
            }
            if (capacity > maxCachedCount || combinations < 0) {
                capacity = maxCachedCount;
            }
            if (capacity < 1) {
                capacity = 1;
            }
            logger.info("cache: max values to query: {}, will cache {}\n{}", Long.valueOf(combinations), Integer.valueOf((int) capacity), this);
            this.cacheEvidenceInContext2proba = new LRUMap<>((int) capacity);
        }
        return this.cacheEvidenceInContext2proba;
    }

    /**
     * Builds a balanced binary tree whose leaves are the given nodes, split in random halves.
     *
     * @param categoricalBayesianNetwork network assigned to newly created internal nodes
     * @param set nodes to assemble; must not be empty
     * @return the single element of {@code set}, or a new internal node rooting all of them
     * @throws IllegalArgumentException if {@code set} is empty (previously this recursed forever)
     */
    protected static DNode compose(CategoricalBayesianNetwork categoricalBayesianNetwork, Set<DNode> set) {
        if (set.isEmpty()) {
            throw new IllegalArgumentException("cannot compose an empty set of dtree nodes");
        }
        if (set.size() == 1) {
            return set.iterator().next();
        }
        logger.trace("composing {}", set);
        // NOTE(review): shuffle uses its own Random, not GenstarRandom, so tree shape is not seed-reproducible.
        List<DNode> shuffled = new ArrayList<>(set);
        Collections.shuffle(shuffled);
        int half = shuffled.size() / 2;
        Set<DNode> leftHalf = new HashSet<>(shuffled.subList(0, half));
        Set<DNode> rightHalf = new HashSet<>(shuffled.subList(half, shuffled.size()));
        DNode composed = new DNode(categoricalBayesianNetwork);
        compose(categoricalBayesianNetwork, leftHalf).becomeLeftChild(composed);
        compose(categoricalBayesianNetwork, rightHalf).becomeRightChild(composed);
        logger.debug("composed {} into {}", set, composed);
        composed.resetCacheAll();
        return composed;
    }

    /**
     * Builds a dtree from an elimination order.
     * <p>
     * Starts with one leaf per network node listed in {@code list}. For each variable in order,
     * all current trees whose {@link #varsInter()} contains it are merged with {@link #compose}.
     * The remaining trees are finally composed into a single root.
     *
     * @param categoricalBayesianNetwork network to decompose
     * @param list elimination order; only network nodes it contains become leaves
     * @return root of the dtree
     * @throws IllegalArgumentException if no network node appears in {@code list}
     */
    public static DNode eliminationOrder2DTree(CategoricalBayesianNetwork categoricalBayesianNetwork, List<NodeCategorical> list) {
        logger.debug("create DTree from elimination order {}", list);
        // toCollection(HashSet::new): the set is mutated below, and toSet() does not guarantee mutability.
        Set<DNode> trees = categoricalBayesianNetwork.getNodes().stream()
                .filter(node -> list.contains(node))
                .map(node -> new DNode(node))
                .collect(Collectors.toCollection(HashSet::new));
        if (trees.isEmpty()) {
            throw new IllegalArgumentException("elimination order " + list + " selects no node of the network");
        }
        for (NodeCategorical variable : list) {
            logger.debug("now decomposing on {}", variable);
            // NOTE(review): Darwiche's el2dt merges trees whose variable UNION contains the variable;
            // varsInter() is kept here to preserve the existing tree shapes. Confirm intent.
            Set<DNode> mentioning = trees.stream()
                    .filter(tree -> tree.varsInter().contains(variable))
                    .collect(Collectors.toCollection(HashSet::new));
            if (mentioning.size() > 1) {
                logger.debug("composition of {}", mentioning);
                trees.removeAll(mentioning);
                trees.add(compose(categoricalBayesianNetwork, mentioning));
            }
        }
        logger.debug("final composition of {}", trees);
        DNode root = compose(categoricalBayesianNetwork, trees);
        root.setNetwork(categoricalBayesianNetwork);
        logger.debug("decomposed into: {}", root);
        return root;
    }

    /**
     * Sets the left child pointer (does not update the child's parent; see becomeLeftChild).
     *
     * @throws IllegalArgumentException if this node holds a factor
     */
    public void setLeft(DNode dNode) {
        if (this.f != null) {
            throw new IllegalArgumentException("cannot add a child to a factor node");
        }
        this.left = dNode;
        resetCacheAll();
    }

    /**
     * Sets the right child pointer (does not update the child's parent; see becomeRightChild).
     *
     * @throws IllegalArgumentException if this node holds a factor
     */
    public void setRight(DNode dNode) {
        if (this.f != null) {
            throw new IllegalArgumentException("cannot add a child to a factor node");
        }
        this.right = dNode;
        resetCacheAll();
    }

    /** @return true when this node has no parent */
    public boolean isRoot() {
        return this.parent == null;
    }

    /** @return true when this node has no children */
    public boolean isLeaf() {
        return this.left == null && this.right == null;
    }

    /** @return true when this node has a parent and both children */
    public boolean isInternal() {
        return this.parent != null && this.left != null && this.right != null;
    }

    /**
     * @return variables mentioned anywhere in this subtree (a leaf returns its factor's variable set as-is)
     */
    public Set<NodeCategorical> varsUnion() {
        if (isLeaf()) {
            return this.f.variables;
        }
        if (this.varsUnion == null) {
            Set<NodeCategorical> union = new HashSet<>(this.left.varsUnion());
            union.addAll(this.right.varsUnion());
            this.varsUnion = union;
            logger.trace("computed vars {}", this.varsUnion);
        }
        return this.varsUnion;
    }

    /**
     * @return variables shared by all leaves of this subtree (a leaf returns its factor's variable set as-is)
     */
    public Set<NodeCategorical> varsInter() {
        if (isLeaf()) {
            return this.f.variables;
        }
        if (this.varsInter == null) {
            Set<NodeCategorical> inter = new HashSet<>(this.left.varsInter());
            inter.retainAll(this.right.varsInter());
            this.varsInter = inter;
            logger.trace("computed varsInter {}", this.varsInter);
        }
        return this.varsInter;
    }

    // Currently unused: cutset before removing the acutset.
    private Set<NodeCategorical> cutsetWithACutset() {
        Set<NodeCategorical> shared = new HashSet<>(this.left.varsUnion());
        shared.retainAll(this.right.varsUnion());
        return shared;
    }

    /**
     * @return variables shared by both children's subtrees, minus this node's acutset (internal nodes only)
     */
    public Set<NodeCategorical> cutset() {
        if (this.cutset == null) {
            Set<NodeCategorical> shared = new HashSet<>(this.left.varsUnion());
            shared.retainAll(this.right.varsUnion());
            shared.removeAll(acutset());
            this.cutset = shared;
            logger.trace("computed cutset {}", shared);
        }
        return this.cutset;
    }

    /** Adds the cutsets of this node and all its ancestors into {@code set}. */
    protected void getUnionCutsetParents(Set<NodeCategorical> set) {
        set.addAll(cutset());
        if (this.parent != null) {
            this.parent.getUnionCutsetParents(set);
        }
    }

    /**
     * @return union of the cutsets of all ancestors (empty for the root)
     */
    public Set<NodeCategorical> acutset() {
        if (isRoot()) {
            return Collections.emptySet();
        }
        if (this.acutset == null) {
            Set<NodeCategorical> ancestorsCutsets = new HashSet<>(this.bn.getNodes().size());
            this.parent.getUnionCutsetParents(ancestorsCutsets);
            this.acutset = ancestorsCutsets;
            logger.trace("computed acutset {}", ancestorsCutsets);
        }
        return this.acutset;
    }

    /**
     * @return variables of this subtree that also belong to its acutset
     */
    public Set<NodeCategorical> context() {
        if (this.context == null) {
            Set<NodeCategorical> ctx = new HashSet<>(varsUnion());
            ctx.retainAll(acutset());
            this.context = ctx;
        }
        return this.context;
    }

    /**
     * @return for a leaf, its variables; otherwise cutset plus context
     */
    public Set<NodeCategorical> cluster() {
        if (isLeaf()) {
            return varsUnion();
        }
        if (this.cluster == null) {
            Set<NodeCategorical> clusterVars = new HashSet<>(cutset());
            clusterVars.addAll(context());
            this.cluster = clusterVars;
        }
        // Previously returned this.context by mistake.
        return this.cluster;
    }

    /**
     * @return largest context size over this node and all its descendants
     */
    public int contextWidth() {
        int width = context().size();
        if (!isLeaf()) {
            width = Math.max(width, Math.max(this.left.contextWidth(), this.right.contextWidth()));
        }
        return width;
    }

    /**
     * Leaf lookup: value of this leaf's factor for the given evidence.
     *
     * @param map evidence (variable to value)
     * @return 1 when the evidence does not assign this leaf's variable; otherwise the factor value for
     *         the evidence restricted to the factor's variables
     */
    protected double lookup(Map<NodeCategorical, String> map) {
        logger.debug("Lookup on {} for {}", this.f, map);
        if (!map.containsKey(this.n)) {
            logger.trace("Not concerned by evidence => 1.");
            return 1.0d;
        }
        HashMap<NodeCategorical, String> restricted = new HashMap<>(map);
        restricted.keySet().retainAll(this.f.variables);
        return this.f.get(restricted);
    }

    /**
     * Convenience overload: converts string evidence through the network's toNodeAndValue.
     */
    public double recursiveConditionning(String... strArr) {
        return recursiveConditionning(this.bn.toNodeAndValue(strArr));
    }

    /**
     * Recursive conditioning over this subtree.
     * <p>
     * Leaves delegate to {@link #lookup(Map)}. Internal nodes sum, over every instantiation of the
     * cutset variables not fixed by {@code map}, the product of both children's results. Results are
     * memoized per node when the context is non-empty, keyed by {@code map} restricted to
     * {@link #varsUnion()}. The sum is clamped at 1 and enumeration stops once it reaches 1.
     *
     * @param map evidence (variable to value); empty evidence yields 1
     * @return the computed value, in [0, 1] for internal nodes
     */
    public double recursiveConditionning(Map<NodeCategorical, String> map) {
        if (map.isEmpty()) {
            return 1.0d;
        }
        if (logger.isDebugEnabled()) {
            logger.debug("Recursive Conditionning for {} on:\n {}", map, this);
        }
        if (isLeaf()) {
            logger.trace("is leaf => lookup");
            return lookup(map);
        }

        HashMap<NodeCategorical, String> cacheKey = new HashMap<>(map);
        cacheKey.keySet().retainAll(varsUnion());
        // Same condition guards both the cache read and the cache write below.
        boolean cacheable = !context().isEmpty() && !cacheKey.isEmpty();
        if (cacheable) {
            logger.trace("search in cache {}", cacheKey);
            Double cached = getCache().get(cacheKey);
            if (cached != null) {
                InferencePerformanceUtils.singleton.incCacheHit();
                return cached.doubleValue();
            }
            InferencePerformanceUtils.singleton.incCacheMiss();
        }

        if (logger.isTraceEnabled()) {
            logger.trace("no leaf => summing over cutset {}", joinNames(cutset()));
        }
        HashSet<NodeCategorical> toEnumerate = new HashSet<>(cutset());
        toEnumerate.removeAll(map.keySet());
        if (logger.isTraceEnabled()) {
            logger.trace("no leaf => exploring combinations over {}", joinNames(toEnumerate));
        }

        double total = 0.0d;
        IteratorCategoricalVariables instantiations = this.bn.iterateDomains(toEnumerate);
        while (instantiations.hasNext()) {
            Map<NodeCategorical, String> instantiation = instantiations.next();
            instantiation.putAll(map);
            if (logger.isTraceEnabled()) {
                logger.trace("Sum of {} over\n {}", instantiation, this);
            }
            double leftValue = this.left.recursiveConditionning(instantiation);
            if (logger.isTraceEnabled()) {
                logger.trace("Sum of {} = {}  ", instantiation, Double.valueOf(leftValue));
            }
            if (leftValue == 0.0d) {
                continue; // product is 0 whatever the right side is: skip it
            }
            double rightValue = this.right.recursiveConditionning(instantiation);
            if (logger.isTraceEnabled()) {
                logger.trace("Sum of {} = {}", instantiation, Double.valueOf(rightValue));
            }
            total += leftValue * rightValue;
            InferencePerformanceUtils.singleton.incAdditions();
            InferencePerformanceUtils.singleton.incMultiplications();
            if (total >= 1.0d) {
                total = 1.0d;
                break;
            }
        }

        if (cacheable) {
            getCache().put(cacheKey, Double.valueOf(total));
        }
        return total;
    }

    // Currently unused heuristic for choosing which child to evaluate first.
    private boolean shouldComputeFirst(DNode dNode, DNode dNode2, Map<NodeCategorical, String> map) {
        if (dNode.isLeaf()) {
            return !dNode2.isLeaf() || dNode.f.variables.size() <= dNode2.f.variables.size();
        }
        return !dNode2.isLeaf() && dNode.cutset().size() <= dNode2.cutset().size();
    }

    /**
     * Checks the structural invariants of this subtree.
     *
     * @throws IllegalArgumentException if a factor node has children, a node has exactly one child,
     *         or parent/child links disagree
     */
    public void checkConsistency() {
        // Previously tested "left != null || left != null", never checking the right child.
        if (this.f != null && (this.left != null || this.right != null)) {
            throw new IllegalArgumentException("only leafs can contain CPTs");
        }
        if ((this.left == null) ^ (this.right == null)) {
            throw new IllegalArgumentException("can have either zero or two children");
        }
        if (this.parent != null && this.parent.right != this && this.parent.left != this) {
            throw new IllegalArgumentException("inconsistent hierarchy: parent should have us as a child");
        }
        if (this.left != null) {
            if (this.left.parent != this) {
                throw new IllegalArgumentException("child should have us as a parent ");
            }
            this.left.checkConsistency();
        }
        if (this.right != null) {
            if (this.right.parent != this) {
                throw new IllegalArgumentException("child should have us as a parent ");
            }
            this.right.checkConsistency();
        }
    }

    /** Joins the names of the given variables with commas. */
    private static String joinNames(Set<NodeCategorical> nodes) {
        return nodes.stream().map(node -> node.name).collect(Collectors.joining(","));
    }

    /**
     * Multi-line, depth-indented rendering of this subtree: leaves show their variable, internal
     * nodes show vars (intersection), cutset, acutset and context.
     */
    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(StringUtils.repeat(ReadIPUMSDictionaryUtils.TAB_SEPARATOR, getDepth()));
        if (isLeaf()) {
            sb.append("|-- ").append(String.valueOf(this.n)).append("\n");
        } else {
            sb.append("|-- vars: ").append(joinNames(varsInter()));
            sb.append(" cutset: ").append(joinNames(cutset()));
            sb.append(" acutset:").append(joinNames(acutset()));
            sb.append(" context:").append(joinNames(context()));
            sb.append("\n");
            sb.append(this.left.toString());
            sb.append(this.right.toString());
        }
        return sb.toString();
    }

    /**
     * Attaches this node as the right child of {@code dNode}, inheriting its network. A previous
     * right child of {@code dNode} is detached (its parent is cleared).
     */
    public void becomeRightChild(DNode dNode) {
        this.parent = dNode;
        this.bn = dNode.bn;
        if (dNode.right != null && dNode.right != this) {
            dNode.right.parent = null;
        }
        dNode.setRight(this);
        resetCacheAll();
    }

    /**
     * Attaches this node as the left child of {@code dNode}, inheriting its network. A previous
     * left child of {@code dNode} is detached (its parent is cleared).
     */
    public void becomeLeftChild(DNode dNode) {
        this.parent = dNode;
        this.bn = dNode.bn;
        if (dNode.left != null && dNode.left != this) {
            dNode.left.parent = null;
        }
        dNode.setLeft(this);
        resetCacheAll();
    }

    /** Sets the network of this node only (descendants are not updated). */
    public void setNetwork(CategoricalBayesianNetwork categoricalBayesianNetwork) {
        this.bn = categoricalBayesianNetwork;
    }

    /**
     * Reduces every leaf factor of this subtree with the given evidence, then invalidates caches.
     */
    public void instanciate(Map<NodeCategorical, String> map) {
        logger.debug("reducing {} based on evidence {}", this, map);
        if (isLeaf()) {
            this.f.reduce(map);
        } else {
            this.right.instanciate(map);
            this.left.instanciate(map);
        }
        resetCacheAll();
    }

    /**
     * Samples values for the variables of this subtree not yet present in {@code map}.
     * <p>
     * Leaves are visited left to right. A leaf whose variable is unassigned reduces its factor with
     * {@code map}, then draws one entry with probability proportional to its weight and puts that
     * entry's assignment into {@code map}. A random number is drawn only when a value is actually
     * picked.
     *
     * @param map known values; completed in place
     */
    public void generate(Map<NodeCategorical, String> map) {
        logger.trace("generating for {} and known {}", this, map);
        if (!isLeaf()) {
            logger.trace("calling left {}", this.left);
            this.left.generate(map);
            logger.trace("calling right {}", this.right);
            this.right.generate(map);
            return;
        }
        if (map.containsKey(this.n)) {
            return;
        }
        logger.trace("picking a value from our factor {}", this.f);
        Map<Map<NodeCategorical, String>, Double> weights = this.f.reduction(map).values;
        double totalWeight = 0.0d;
        for (Double weight : weights.values()) {
            totalWeight += weight.doubleValue();
        }
        if (totalWeight <= 0.0d) {
            logger.warn("cannot generate a value for {}: reduced factor has no positive weight", this.n);
            return;
        }
        // Scale by the actual total mass (previously scaled by the parent count, which always picked
        // the first value for root variables and could pick nothing with >= 2 parents).
        double threshold = GenstarRandom.getInstance().nextDouble() * totalWeight;
        double cumulated = 0.0d;
        Map<NodeCategorical, String> lastPositive = null;
        for (Map.Entry<Map<NodeCategorical, String>, Double> entry : weights.entrySet()) {
            double weight = entry.getValue().doubleValue();
            if (weight <= 0.0d) {
                continue; // impossible values are never picked
            }
            cumulated += weight;
            lastPositive = entry.getKey();
            if (cumulated >= threshold) {
                map.putAll(entry.getKey());
                logger.trace("picked from CPT: {}", entry.getKey());
                return;
            }
        }
        // Floating-point round-off may leave cumulated slightly below threshold.
        map.putAll(lastPositive);
        logger.trace("picked from CPT: {}", lastPositive);
    }

    /** Exports this subtree as a Graphviz file, labelling internal nodes with their cutset. */
    public final void exportAsGraphviz(File file) {
        exportAsGraphviz(file, "cutset");
    }

    /**
     * Exports this subtree as a Graphviz "dot" file.
     *
     * @param file destination (overwritten)
     * @param str name of a no-arg method of this class returning a set of variables (e.g.
     *        "cutset", "context", "varsUnion"), used to label internal nodes
     * @throws RuntimeException wrapping the IOException if the file cannot be written
     */
    public final void exportAsGraphviz(File file, String str) {
        StringBuilder dot = new StringBuilder();
        dot.append("# to generate it, use:\n# dot -Tjpg " + file.getAbsolutePath() + " -o " + file.getAbsolutePath() + ".jpg\n");
        dot.append("# if you also want to open it under linux:\n# dot -Tjpg " + file.getAbsolutePath() + " -o " + file.getAbsolutePath() + ".jpg && xdg-open " + file.getAbsolutePath() + ".jpg\n");
        dot.append("digraph dtree {\n");
        dot.append("\trankdir=TB;\n");
        exportAsGraphvizInto(dot, new HashMap<>(), str);
        dot.append("}\n");
        // try-with-resources: the writer was previously leaked when write() failed.
        try (FileWriter fileWriter = new FileWriter(file)) {
            fileWriter.write(dot.toString());
        } catch (IOException e) {
            throw new RuntimeException("error while exporting dtree into file " + file, e);
        }
    }

    /**
     * Appends this node, its edge from its parent, and its subtree to {@code sb}.
     *
     * @param sb output buffer
     * @param map node to Graphviz id, assigned in visit order
     * @param str label method name (see exportAsGraphviz)
     */
    private void exportAsGraphvizInto(StringBuilder sb, Map<DNode, String> map, String str) {
        String id = map.get(this);
        if (id == null) {
            id = Integer.toString(map.size());
            map.put(this, id);
        }
        sb.append("\t\"").append(id).append("\" [label=\"");
        if (this.n != null) {
            sb.append(this.n);
        } else {
            try {
                @SuppressWarnings("unchecked") // the named method is expected to return Set<NodeCategorical>
                Set<NodeCategorical> labelVars = (Set<NodeCategorical>) getClass().getDeclaredMethod(str).invoke(this);
                sb.append(str).append(": {").append(joinNames(labelVars)).append("}");
            } catch (IllegalAccessException | IllegalArgumentException | NoSuchMethodException | SecurityException | InvocationTargetException | ClassCastException e) {
                logger.warn("unable to compute label {} for dtree node", str, e);
                sb.append(str).append(": {?}");
            }
        }
        sb.append("\"];\n");
        if (this.parent != null && map.containsKey(this.parent)) {
            sb.append("\t\"").append(map.get(this.parent)).append("\" -> ").append("\"").append(id).append("\";\n");
        }
        if (this.left != null) {
            this.left.exportAsGraphvizInto(sb, map, str);
        }
        if (this.right != null) {
            this.right.exportAsGraphvizInto(sb, map, str);
        }
    }

    /**
     * Reduces every factor of this subtree with the given evidence. Each node clears only its own
     * memoized state after its children have been processed.
     */
    public void reduce(Map<NodeCategorical, String> map) {
        if (this.f != null) {
            this.f.reduce(map);
        }
        if (this.right != null) {
            this.right.reduce(map);
        }
        if (this.left != null) {
            this.left.reduce(map);
        }
        resetCache();
    }
}
