/*
 * Decompiled with CFR 0.152.
 */
package org.apache.sysds.runtime.transform.encode;

import java.io.IOException;
import java.io.ObjectInput;
import java.io.ObjectOutput;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;
import org.apache.commons.lang.ArrayUtils;
import org.apache.sysds.api.DMLScript;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.functionobjects.KahanPlus;
import org.apache.sysds.runtime.functionobjects.Mean;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.data.FrameBlock;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.transform.TfUtils;
import org.apache.sysds.runtime.transform.encode.LegacyEncoder;
import org.apache.sysds.runtime.transform.meta.TfMetaUtils;
import org.apache.sysds.runtime.util.IndexRange;
import org.apache.sysds.runtime.util.UtilFunctions;
import org.apache.sysds.utils.stats.TransformStatistics;
import org.apache.wink.json4j.JSONArray;
import org.apache.wink.json4j.JSONException;
import org.apache.wink.json4j.JSONObject;

/**
 * Legacy transform encoder that imputes missing values per column.
 *
 * <p>Each configured column uses one {@link MVMethod}:
 * <ul>
 *   <li>{@code GLOBAL_MEAN}: a running Kahan mean of the non-null cells.</li>
 *   <li>{@code GLOBAL_MODE}: the most frequent non-null, non-empty string value.</li>
 *   <li>{@code CONSTANT}: the {@code value} given in the transform spec.</li>
 * </ul>
 * {@link #build(FrameBlock)} collects the statistics. {@link #apply(FrameBlock, MatrixBlock)}
 * writes the replacement into every NaN cell of the configured output columns. Column IDs in
 * {@code _colList} are 1-based (cells are accessed as {@code colID - 1}). All per-column arrays
 * are indexed in parallel with {@code _colList}.
 */
public class EncoderMVImpute
extends LegacyEncoder {
    private static final long serialVersionUID = 9057868620144662194L;
    private final Mean _meanFn = Mean.getMeanFnObject();
    /** Impute method per entry of {@code _colList} (same index). */
    private MVMethod[] _mvMethodList = null;
    /** Running mean per column; only updated for GLOBAL_MEAN columns. */
    private KahanObject[] _meanList = null;
    /** Number of non-missing values folded into the running mean, per column. */
    private long[] _countList = null;
    /** Replacement value per column as a string; parsed as a double in {@link #apply}. */
    private String[] _replacementList = null;
    /** Column IDs whose replacement must be mapped through the recode map in {@link #initMetaData}. */
    private List<Integer> _rcList = null;
    /** Value histograms of GLOBAL_MODE columns, keyed by column ID. */
    private HashMap<Integer, HashMap<String, Long>> _hist = new HashMap<>();

    /**
     * Creates the encoder from the IMPUTE section of a parsed transform spec.
     *
     * @param parsedSpec parsed transform specification
     * @param colnames   column names, used when the spec refers to columns by name
     * @param clen       number of columns
     * @param minCol     lower column bound forwarded to the column-list parser (-1 for none)
     * @param maxCol     upper column bound forwarded to the column-list parser (-1 for none)
     * @throws JSONException if the spec is malformed
     */
    public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen, int minCol, int maxCol) throws JSONException {
        super(null, clen);
        int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfUtils.TfMethod.IMPUTE.toString(), minCol, maxCol);
        initColList(collist);
        parseMethodsAndReplacements(parsedSpec, colnames, minCol);
    }

    /** Creates an empty encoder; state is expected to be filled via {@link #readExternal}. */
    public EncoderMVImpute() {
        super(new int[0], 0);
    }

    /**
     * Creates an encoder from already prepared per-column arrays (indexed in parallel with colList).
     * The arrays are stored without copying.
     */
    public EncoderMVImpute(int[] colList, MVMethod[] mvMethodList, String[] replacementList, KahanObject[] meanList, long[] countList, List<Integer> rcList, int clen) {
        super(colList, clen);
        _mvMethodList = mvMethodList;
        _replacementList = replacementList;
        _meanList = meanList;
        _countList = countList;
        _rcList = rcList;
    }

    /**
     * Unpacks a column-ID -> {@link ColInfo} map into parallel arrays, in the map's iteration order.
     * Non-null histograms are put into {@code hist} keyed by column ID.
     */
    private static void fillListsFromMap(Map<Integer, ColInfo> map, int[] colList, MVMethod[] mvMethodList, String[] replacementList, KahanObject[] meanList, long[] countList, HashMap<Integer, HashMap<String, Long>> hist) {
        int i = 0;
        for (Map.Entry<Integer, ColInfo> entry : map.entrySet()) {
            ColInfo info = entry.getValue();
            colList[i] = entry.getKey();
            mvMethodList[i] = info._method;
            replacementList[i] = info._replacement;
            meanList[i] = info._mean;
            countList[i] = info._count;
            // Null histograms would later break build() and writeExternal(); skip them.
            if (info._hist != null) {
                hist.put(entry.getKey(), info._hist);
            }
            i++;
        }
    }

    /** Returns an independent copy of a Kahan accumulator (null-safe). */
    private static KahanObject copyMean(KahanObject mean) {
        return mean == null ? null : new KahanObject(mean._sum, mean._correction);
    }

    /** Returns an independent copy of a histogram (null-safe). */
    private static HashMap<String, Long> copyHist(HashMap<String, Long> hist) {
        return hist == null ? null : new HashMap<String, Long>(hist);
    }

    /**
     * Returns the key with the highest count; on ties the first one in iteration order wins.
     * Returns null for a null or empty histogram.
     */
    private static String modeOf(Map<String, Long> hist) {
        if (hist == null) {
            return null;
        }
        String mode = null;
        long max = Long.MIN_VALUE;
        for (Map.Entry<String, Long> e : hist.entrySet()) {
            if (e.getValue() > max) {
                mode = e.getKey();
                max = e.getValue();
            }
        }
        return mode;
    }

    /** Returns the internal replacement array (not a copy), indexed like the column list. */
    public String[] getReplacements() {
        return _replacementList;
    }

    /** Returns the internal running-mean array (not a copy), indexed like the column list. */
    public KahanObject[] getMeans() {
        return _meanList;
    }

    /**
     * Reads the impute method (and constant value) of every spec entry.
     *
     * <p>The entry is stored at the position of its column in the sorted {@code _colList}, so that
     * all per-column arrays stay aligned with {@code _colList} regardless of spec order.
     */
    private void parseMethodsAndReplacements(JSONObject parsedSpec, String[] colnames, int offset) throws JSONException {
        JSONArray mvspec = (JSONArray) parsedSpec.get(TfUtils.TfMethod.IMPUTE.toString());
        boolean ids = parsedSpec.containsKey("ids") && parsedSpec.getBoolean("ids");
        // Sorted for the binary search below.
        Arrays.sort(_colList);
        int numCols = _colList.length;
        _mvMethodList = new MVMethod[numCols];
        _replacementList = new String[numCols];
        _meanList = new KahanObject[numCols];
        _countList = new long[numCols];
        // NOTE(review): spec ids are shifted by (offset - 1) but spec names are not; confirm this
        // matches how TfMetaUtils.parseJsonObjectIDList produced _colList.
        int ixOffset = offset == -1 ? 0 : offset - 1;
        for (Object o : mvspec) {
            JSONObject mvobj = (JSONObject) o;
            int colID = ids
                ? mvobj.getInt("id") - ixOffset
                : ArrayUtils.indexOf((Object[]) colnames, mvobj.get("name")) + 1;
            int pos = Arrays.binarySearch(_colList, colID);
            if (pos < 0) {
                continue; // not one of this encoder's columns
            }
            _mvMethodList[pos] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase(Locale.ROOT));
            if (_mvMethodList[pos] == MVMethod.CONSTANT) {
                _replacementList[pos] = mvobj.getString("value");
            }
        }
        // Columns without a matching spec entry get a neutral configuration instead of nulls.
        for (int i = 0; i < numCols; i++) {
            if (_mvMethodList[i] == null) {
                _mvMethodList[i] = MVMethod.INVALID;
            }
            _meanList[i] = new KahanObject(0.0, 0.0);
        }
    }

    /** Returns the impute method of a column, or INVALID if this encoder does not handle it. */
    public MVMethod getMethod(int colID) {
        int idx = isApplicable(colID);
        if (idx == -1) {
            return MVMethod.INVALID;
        }
        return _mvMethodList[idx];
    }

    /** Returns the number of non-missing values seen for a column, or 0 if not handled. */
    public long getNonMVCount(int colID) {
        int idx = isApplicable(colID);
        return idx == -1 ? 0L : _countList[idx];
    }

    /** Returns the replacement value of a column, or null if not handled. */
    public String getReplacement(int colID) {
        int idx = isApplicable(colID);
        return idx == -1 ? null : _replacementList[idx];
    }

    /** Builds statistics from {@code in} and then imputes missing values in {@code out}. */
    @Override
    public MatrixBlock encode(FrameBlock in, MatrixBlock out) {
        build(in);
        return apply(in, out);
    }

    /**
     * Accumulates impute statistics from {@code in}. Repeated calls keep accumulating.
     *
     * <p>For GLOBAL_MEAN, null cells are skipped, and the running mean and non-missing count are
     * updated. For GLOBAL_MODE, non-null, non-empty values are counted into the column histogram.
     * In both cases the replacement value is refreshed. Any failure is rethrown as a
     * RuntimeException.
     */
    @Override
    public void build(FrameBlock in) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        try {
            int numRows = in.getNumRows();
            for (int j = 0; j < _colList.length; j++) {
                int colID = _colList[j];
                if (_mvMethodList[j] == MVMethod.GLOBAL_MEAN) {
                    // 'off' starts at the count of earlier builds and is decremented per null cell.
                    // So off + i + 1 is the number of non-missing values seen including cell i.
                    long off = _countList[j];
                    for (int i = 0; i < numRows; i++) {
                        Object key = in.get(i, colID - 1);
                        if (key == null) {
                            off--;
                            continue;
                        }
                        _meanFn.execute2(_meanList[j], UtilFunctions.objectToDouble(in.getSchema()[colID - 1], key), off + i + 1L);
                    }
                    _replacementList[j] = String.valueOf(_meanList[j]._sum);
                    // off + numRows == previous count + non-missing cells in this block.
                    _countList[j] = off + numRows;
                }
                else if (_mvMethodList[j] == MVMethod.GLOBAL_MODE) {
                    if (_hist == null) {
                        _hist = new HashMap<>();
                    }
                    HashMap<String, Long> hist = _hist.computeIfAbsent(colID, k -> new HashMap<>());
                    for (int i = 0; i < numRows; i++) {
                        String key = String.valueOf(in.get(i, colID - 1));
                        // String.valueOf(null) yields "null"; treat it like an empty cell.
                        if (key.equals("null") || key.isEmpty()) {
                            continue;
                        }
                        hist.merge(key, 1L, Long::sum);
                    }
                    String mode = modeOf(hist);
                    if (mode != null) {
                        _replacementList[j] = mode;
                    }
                }
            }
        }
        catch (Exception ex) {
            throw new RuntimeException(ex);
        }
        if (DMLScript.STATISTICS) {
            TransformStatistics.incImputeBuildTime(System.nanoTime() - t0);
        }
    }

    /**
     * Replaces every NaN in the configured columns of {@code out} with the parsed replacement.
     * Rows 0 .. in.getNumRows()-1 are visited. A null or non-numeric replacement makes
     * Double.parseDouble throw when a NaN is encountered.
     */
    @Override
    public MatrixBlock apply(FrameBlock in, MatrixBlock out) {
        long t0 = DMLScript.STATISTICS ? System.nanoTime() : 0L;
        for (int i = 0; i < in.getNumRows(); i++) {
            for (int j = 0; j < _colList.length; j++) {
                int colID = _colList[j];
                if (!Double.isNaN(out.quickGetValue(i, colID - 1))) {
                    continue;
                }
                out.quickSetValue(i, colID - 1, Double.parseDouble(_replacementList[j]));
            }
        }
        if (DMLScript.STATISTICS) {
            TransformStatistics.incImputeApplyTime(System.nanoTime() - t0);
        }
        return out;
    }

    /**
     * Creates an encoder restricted to the columns inside {@code ixRange}, or null if none are.
     *
     * <p>Means and histograms are copied, so building the sub-encoder does not modify this one.
     * Recoded column IDs are shifted relative to the range start.
     */
    @Override
    public LegacyEncoder subRangeEncoder(IndexRange ixRange) {
        // TreeMap keeps the resulting column list in ascending order.
        Map<Integer, ColInfo> map = new TreeMap<>();
        for (int i = 0; i < _colList.length; i++) {
            int colID = _colList[i];
            if (!ixRange.inColRange(colID)) {
                continue;
            }
            map.put(colID, new ColInfo(_mvMethodList[i], _replacementList[i], copyMean(_meanList[i]), _countList[i], copyHist(getHistogram(colID))));
        }
        if (map.isEmpty()) {
            return null;
        }
        int n = map.size();
        int[] colList = new int[n];
        MVMethod[] mvMethodList = new MVMethod[n];
        String[] replacementList = new String[n];
        KahanObject[] meanList = new KahanObject[n];
        long[] countList = new long[n];
        HashMap<Integer, HashMap<String, Long>> hist = new HashMap<>();
        fillListsFromMap(map, colList, mvMethodList, replacementList, meanList, countList, hist);
        if (_rcList == null) {
            _rcList = new ArrayList<Integer>();
        }
        // NOTE(review): rcList is shifted to range-relative IDs while colList keeps the original
        // IDs (mergeAt mirrors this); confirm the intended ID space of sub-range encoders.
        List<Integer> rcList = _rcList.stream().filter(ixRange::inColRange).map(i -> (int) ((long) i.intValue() - (ixRange.colStart - 1L))).collect(Collectors.toList());
        EncoderMVImpute sub = new EncoderMVImpute(colList, mvMethodList, replacementList, meanList, countList, rcList, (int) ixRange.colSpan());
        sub._hist = hist;
        return sub;
    }

    /**
     * Merges another impute encoder into this one.
     *
     * <p>Columns present in both encoders are merged via {@link ColInfo#merge}, which combines
     * means, counts and histograms and refreshes the replacement. Columns present only in the
     * other encoder are copied. The other encoder's recoded column IDs are shifted by
     * {@code col - 1}. Non-impute encoders are delegated to the superclass.
     */
    @Override
    public void mergeAt(LegacyEncoder other, int row, int col) {
        if (!(other instanceof EncoderMVImpute)) {
            super.mergeAt(other, row, col);
            return;
        }
        EncoderMVImpute otherImpute = (EncoderMVImpute) other;
        // TreeMap keeps the merged column list in ascending order.
        Map<Integer, ColInfo> map = new TreeMap<>();
        for (int i = 0; i < _colList.length; i++) {
            int colID = _colList[i];
            map.put(colID, new ColInfo(_mvMethodList[i], _replacementList[i], _meanList[i], _countList[i], getHistogram(colID)));
        }
        for (int i = 0; i < otherImpute._colList.length; i++) {
            int colID = otherImpute._colList[i];
            // Copies avoid sharing mutable state with the other encoder.
            ColInfo otherColInfo = new ColInfo(otherImpute._mvMethodList[i], otherImpute._replacementList[i], copyMean(otherImpute._meanList[i]), otherImpute._countList[i], copyHist(otherImpute.getHistogram(colID)));
            ColInfo colInfo = map.get(colID);
            if (colInfo == null) {
                map.put(colID, otherColInfo);
            }
            else {
                colInfo.merge(otherColInfo);
            }
        }
        int n = map.size();
        _colList = new int[n];
        _mvMethodList = new MVMethod[n];
        _replacementList = new String[n];
        _meanList = new KahanObject[n];
        _countList = new long[n];
        _hist = new HashMap<>();
        fillListsFromMap(map, _colList, _mvMethodList, _replacementList, _meanList, _countList, _hist);
        if (_rcList == null) {
            _rcList = new ArrayList<Integer>();
        }
        HashSet<Integer> rcSet = new HashSet<Integer>(_rcList);
        if (otherImpute._rcList != null) {
            rcSet.addAll(otherImpute._rcList.stream().map(i -> i + (col - 1)).collect(Collectors.toSet()));
        }
        _rcList = new ArrayList<Integer>(rcSet);
    }

    /** Writes each column's replacement as the missing-value entry of its column metadata. */
    @Override
    public FrameBlock getMetaData(FrameBlock out) {
        for (int j = 0; j < _colList.length; j++) {
            out.getColumnMetadata(_colList[j] - 1).setMvValue(_replacementList[j]);
        }
        return out;
    }

    /**
     * Loads replacement values from column metadata. For recoded columns the value is mapped
     * through the recode map.
     *
     * @throws RuntimeException if a recoded column has no code for its impute value
     */
    @Override
    public void initMetaData(FrameBlock meta) {
        for (int j = 0; j < _colList.length; j++) {
            int colID = _colList[j];
            String mvVal = UtilFunctions.unquote(meta.getColumnMetadata(colID - 1).getMvValue());
            if (_rcList != null && _rcList.contains(colID)) {
                Long code = meta.getRecodeMap(colID - 1).get(mvVal);
                if (code == null) {
                    throw new RuntimeException("Missing recode value for impute value '" + mvVal + "' (colID=" + colID + ").");
                }
                _replacementList[j] = code.toString();
            }
            else {
                _replacementList[j] = mvVal;
            }
        }
    }

    /** Sets the column IDs that are also recoded (stored without copying). */
    public void initRecodeIDList(List<Integer> rcList) {
        _rcList = rcList;
    }

    /** Returns the value histogram of a GLOBAL_MODE column, or null if none exists. */
    public HashMap<String, Long> getHistogram(int colID) {
        return _hist == null ? null : _hist.get(colID);
    }

    /**
     * Serializes the encoder. The wire format is: superclass state; (method ordinal, count) per
     * column; the count and (index, value) pairs of non-null replacements; the recoded column IDs;
     * and the histograms.
     *
     * <p>NOTE: running means ({@code _meanList}) are not written. {@link #readExternal} resets
     * them to zero.
     */
    @Override
    public void writeExternal(ObjectOutput out) throws IOException {
        super.writeExternal(out);
        for (int i = 0; i < _colList.length; i++) {
            out.writeByte(_mvMethodList[i].ordinal());
            out.writeLong(_countList[i]);
        }
        int numReplacements = 0;
        for (String replacement : _replacementList) {
            if (replacement != null) {
                numReplacements++;
            }
        }
        out.writeInt(numReplacements);
        for (int i = 0; i < _replacementList.length; i++) {
            if (_replacementList[i] == null) {
                continue;
            }
            out.writeInt(i);
            out.writeUTF(_replacementList[i]);
        }
        List<Integer> rcList = _rcList == null ? Collections.<Integer>emptyList() : _rcList;
        out.writeInt(rcList.size());
        for (int rc : rcList) {
            out.writeInt(rc);
        }
        int histSize = 0;
        if (_hist != null) {
            for (HashMap<String, Long> h : _hist.values()) {
                if (h != null) {
                    histSize++;
                }
            }
        }
        out.writeInt(histSize);
        if (histSize > 0) {
            for (Map.Entry<Integer, HashMap<String, Long>> e1 : _hist.entrySet()) {
                if (e1.getValue() == null) {
                    continue;
                }
                out.writeInt(e1.getKey());
                out.writeInt(e1.getValue().size());
                for (Map.Entry<String, Long> e2 : e1.getValue().entrySet()) {
                    out.writeUTF(e2.getKey());
                    out.writeLong(e2.getValue());
                }
            }
        }
    }

    /**
     * Restores the state written by {@link #writeExternal}. Running means are reset to zero
     * because they are not serialized.
     */
    @Override
    public void readExternal(ObjectInput in) throws IOException {
        super.readExternal(in);
        int numCols = _colList.length;
        _mvMethodList = new MVMethod[numCols];
        _countList = new long[numCols];
        _meanList = new KahanObject[numCols];
        _replacementList = new String[numCols];
        for (int i = 0; i < numCols; i++) {
            _mvMethodList[i] = MVMethod.values()[in.readByte()];
            _countList[i] = in.readLong();
            _meanList[i] = new KahanObject(0.0, 0.0);
        }
        int numReplacements = in.readInt();
        for (int i = 0; i < numReplacements; i++) {
            int index = in.readInt();
            _replacementList[index] = in.readUTF();
        }
        int numRc = in.readInt();
        _rcList = new ArrayList<Integer>();
        for (int i = 0; i < numRc; i++) {
            _rcList.add(in.readInt());
        }
        _hist = new HashMap<>();
        int numHists = in.readInt();
        for (int i = 0; i < numHists; i++) {
            int colID = in.readInt();
            int numEntries = in.readInt();
            HashMap<String, Long> hist = new HashMap<String, Long>();
            for (int j = 0; j < numEntries; j++) {
                String key = in.readUTF();
                long value = in.readLong();
                hist.put(key, value);
            }
            _hist.put(colID, hist);
        }
    }

    /** Per-column impute state used while restricting or merging encoders. */
    private static class ColInfo {
        MVMethod _method;
        String _replacement;
        KahanObject _mean;
        long _count;
        HashMap<String, Long> _hist;

        ColInfo(MVMethod method, String replacement, KahanObject mean, long count, HashMap<String, Long> hist) {
            _method = method;
            _replacement = replacement;
            _mean = mean;
            _count = count;
            _hist = hist;
        }

        /**
         * Merges the statistics of another column state into this one (in place) and refreshes
         * the replacement value.
         *
         * @throws DMLRuntimeException if the methods differ, or for methods other than
         *                             CONSTANT, GLOBAL_MEAN and GLOBAL_MODE
         */
        public void merge(ColInfo otherColInfo) {
            if (_method != otherColInfo._method) {
                throw new DMLRuntimeException("Tried to merge two different impute methods: " + _method.name() + " vs. " + otherColInfo._method.name());
            }
            switch (_method) {
                case CONSTANT: {
                    assert (_replacement == null ? otherColInfo._replacement == null : _replacement.equals(otherColInfo._replacement));
                    break;
                }
                case GLOBAL_MEAN: {
                    long total = _count + otherColInfo._count;
                    // Scale both means back to sums, add them with Kahan summation, then divide
                    // by the combined count so _mean remains a mean.
                    _mean._sum *= (double) _count;
                    _mean._correction *= (double) _count;
                    KahanPlus kplus = KahanPlus.getKahanPlusFnObject();
                    kplus.execute((Data) _mean, otherColInfo._mean._sum * (double) otherColInfo._count);
                    kplus.execute((Data) _mean, otherColInfo._mean._correction * (double) otherColInfo._count);
                    if (total > 0) {
                        _mean._sum /= (double) total;
                        _mean._correction /= (double) total;
                        _replacement = String.valueOf(_mean._sum);
                    }
                    _count = total;
                    break;
                }
                case GLOBAL_MODE: {
                    if (otherColInfo._hist != null) {
                        if (_hist == null) {
                            _hist = new HashMap<String, Long>(otherColInfo._hist);
                        }
                        else {
                            // Add counts, including keys that only the other histogram has.
                            otherColInfo._hist.forEach((key, count) -> _hist.merge(key, count, Long::sum));
                        }
                    }
                    String mode = modeOf(_hist);
                    if (mode != null) {
                        _replacement = mode;
                    }
                    break;
                }
                default: {
                    throw new DMLRuntimeException("Method `" + _method.name() + "` not supported for federated impute");
                }
            }
        }
    }

    /** Supported missing-value imputation methods. */
    public enum MVMethod {
        INVALID,
        GLOBAL_MEAN,
        GLOBAL_MODE,
        CONSTANT;
    }
}

