package gospl.algo.sr.bn;

import core.metamodel.attribute.Attribute;
import core.metamodel.attribute.AttributeFactory;
import core.metamodel.entity.ADemoEntity;
import core.metamodel.value.IValue;
import core.util.data.GSEnumDataType;
import core.util.excpetion.GSIllegalRangedData;
import gospl.sampler.ICompletionSampler;
import java.util.HashMap;
import java.util.Map;

/* loaded from: input_file:gospl/algo/sr/bn/BayesianNetworkCompletionSampler.class */
/**
 * Completion sampler that fills in the missing attributes of an entity by sampling a
 * {@link CategoricalBayesianNetwork}.
 *
 * <p>Each attribute the entity already has, and that matches a network variable by name, is given to
 * the inference engine as evidence. One joint sample is then drawn from the engine, and every sampled
 * variable that the entity does not yet have is written onto a clone of the entity.
 *
 * <p>Network variables are matched to population attributes by name. Both directions of that mapping
 * are cached in this instance.
 */
public class BayesianNetworkCompletionSampler implements ICompletionSampler<ADemoEntity> {

    /** The network whose variables are sampled. */
    private final CategoricalBayesianNetwork bn;

    /** Engine used to register evidence and draw samples. */
    private final AbstractInferenceEngine engine;

    /** Cache: network variable name to the population attribute representing it. */
    private final Map<String, Attribute<? extends IValue>> bnVariable2popAttribute;

    /**
     * Cache: population attribute to the matching network variable.
     *
     * <p>A {@code null} value records that the attribute has no matching variable, so the lookup is not
     * repeated.
     */
    private final Map<Attribute<? extends IValue>, NodeCategorical> popAttribute2bnVariable;

    /**
     * Creates a sampler for the given network, using an {@link EliminationInferenceEngine} on that
     * network.
     *
     * @param categoricalBayesianNetwork the network to sample from
     * @throws GSIllegalRangedData if a nominal attribute cannot be created for one of the network's nodes
     */
    public BayesianNetworkCompletionSampler(CategoricalBayesianNetwork categoricalBayesianNetwork) throws GSIllegalRangedData {
        this(categoricalBayesianNetwork, new EliminationInferenceEngine(categoricalBayesianNetwork));
    }

    /**
     * Creates a sampler for the given network using the given inference engine.
     *
     * <p>A nominal population attribute is created eagerly for every node of the network.
     *
     * @param categoricalBayesianNetwork the network to sample from
     * @param abstractInferenceEngine the engine used to add evidence and draw samples
     * @throws GSIllegalRangedData if a nominal attribute cannot be created for one of the network's nodes
     */
    public BayesianNetworkCompletionSampler(CategoricalBayesianNetwork categoricalBayesianNetwork, AbstractInferenceEngine abstractInferenceEngine) throws GSIllegalRangedData {
        this.bn = categoricalBayesianNetwork;
        this.engine = abstractInferenceEngine;
        final int nodeCount = categoricalBayesianNetwork.getNodes().size();
        this.bnVariable2popAttribute = new HashMap<>(nodeCount);
        for (NodeCategorical node : categoricalBayesianNetwork.getNodes()) {
            this.bnVariable2popAttribute.put(node.getName(), createNominalAttribute(node));
        }
        this.popAttribute2bnVariable = new HashMap<>(nodeCount);
    }

    /**
     * Builds a nominal population attribute from a node's name and domain.
     *
     * @param node the network node to mirror as an attribute
     * @return the newly created attribute
     * @throws GSIllegalRangedData if the attribute factory rejects the node's domain
     */
    private static Attribute<? extends IValue> createNominalAttribute(NodeCategorical node) throws GSIllegalRangedData {
        return AttributeFactory.getFactory().createAttribute(node.getName(), GSEnumDataType.Nominal, node.getDomain());
    }

    /**
     * Returns the network variable whose name equals the attribute's name, caching the result.
     *
     * @param attribute a population attribute
     * @return the matching network variable, or {@code null} if the network reports none
     */
    protected NodeCategorical getBNVariableForAttribute(Attribute<? extends IValue> attribute) {
        if (this.popAttribute2bnVariable.containsKey(attribute)) {
            return this.popAttribute2bnVariable.get(attribute);
        }
        NodeCategorical node = this.bn.getVariable(attribute.getAttributeName());
        // Cache misses as well, so unknown attributes are not looked up again.
        this.popAttribute2bnVariable.put(attribute, node);
        // The reverse mapping only makes sense for attributes that exist in the network; without this
        // guard an unknown attribute would trigger a NullPointerException here.
        if (node != null) {
            this.bnVariable2popAttribute.put(node.getName(), attribute);
        }
        return node;
    }

    /**
     * Returns the population attribute that represents the given network variable.
     *
     * <p>If no attribute is cached for the variable's name, a nominal attribute is created and cached.
     *
     * @param nodeCategorical a network variable
     * @return the matching population attribute
     * @throws RuntimeException wrapping {@link GSIllegalRangedData} if the attribute cannot be created
     */
    protected Attribute<? extends IValue> getPopulationAttributeForBNVariable(NodeCategorical nodeCategorical) {
        Attribute<? extends IValue> attribute = this.bnVariable2popAttribute.get(nodeCategorical.getName());
        if (attribute == null) {
            try {
                attribute = createNominalAttribute(nodeCategorical);
            } catch (GSIllegalRangedData e) {
                throw new RuntimeException("unable to create attribute", e);
            }
            this.bnVariable2popAttribute.put(nodeCategorical.getName(), attribute);
        }
        return attribute;
    }

    /**
     * Returns a clone of {@code aDemoEntity} whose missing attributes are filled from one sample of the
     * network.
     *
     * <p>The entity's existing values are used as evidence, and they are never overwritten in the
     * returned clone.
     *
     * @param aDemoEntity the entity to complete; it is not modified
     * @return a completed clone of the entity
     */
    @Override
    public ADemoEntity complete(ADemoEntity aDemoEntity) {
        // NOTE(review): evidence is added to the shared engine but never cleared in this method, so
        // evidence from earlier calls may still be registered. Confirm whether the engine should be
        // reset per entity.
        for (Attribute<? extends IValue> attribute : aDemoEntity.getAttributes()) {
            NodeCategorical node = getBNVariableForAttribute(attribute);
            if (node != null) {
                this.engine.addEvidence(node, aDemoEntity.getValueForAttribute(attribute).getStringValue());
            }
        }
        // Diagnostic output to stderr.
        System.err.println("inference with evidence : " + this.engine.evidenceVariable2value);
        System.err.println("p(evidence): " + this.engine.getProbabilityEvidence());

        Map<NodeCategorical, String> sample = this.engine.sampleOne();
        ADemoEntity clone = aDemoEntity.clone();
        for (Map.Entry<NodeCategorical, String> entry : sample.entrySet()) {
            Attribute<? extends IValue> populationAttribute = getPopulationAttributeForBNVariable(entry.getKey());
            // Only fill in attributes the entity does not already have.
            if (!clone.hasAttribute(populationAttribute)) {
                clone.setAttributeValue(populationAttribute, populationAttribute.getValueSpace().addValue(entry.getValue()));
            }
        }
        return clone;
    }

    /**
     * CSV export is not implemented by this sampler.
     *
     * @param str ignored
     * @return always {@code null}
     */
    @Override
    public String toCsv(String str) {
        return null;
    }
}
