/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.trendz.service.predict.executors;

import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.type.CollectionType;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.thingsboard.trendz.domain.base.Point;
import org.thingsboard.trendz.domain.base.TimeSeries;
import org.thingsboard.trendz.domain.definition.view.PredictionType;
import org.thingsboard.trendz.security.entity.JwtSecurityUser;
import org.thingsboard.trendz.service.predict.PredictionMethodData;
import org.thingsboard.trendz.service.predict.PredictionMethodExecutor;
import org.thingsboard.trendz.service.script.engine.PythonScriptEngine;
import org.thingsboard.trendz.tools.json.JsonUtils;
import reactor.core.publisher.Mono;

/**
 * {@link PredictionMethodExecutor} that forecasts a time series with a SARIMAX model.
 *
 * <p>The model is fitted and evaluated by a Python script run through
 * {@link PythonScriptEngine}; this class only marshals the input series and converts
 * the script's result back into a {@link TimeSeries}.
 */
@Service
public class SarimaxMethod
implements PredictionMethodExecutor {
    private static final Logger log = LoggerFactory.getLogger(SarimaxMethod.class);

    /**
     * Python script executed by the script engine. It reads the variables {@code inputX},
     * {@code inputY}, {@code outputX} and {@code outputY} from the input map, picks ARIMA orders
     * with pmdarima's {@code auto_arima}, refits them with statsmodels' {@code SARIMAX} and
     * returns {@code len(outputX)} forecast values.
     *
     * <p>NOTE(review): {@code pd.to_datetime(inputX)} uses pandas' default unit ({@code ns}).
     * If {@code inputX} holds epoch milliseconds, the index dates are wrong. The forecast itself
     * is driven only by {@code len(outputX)}, so the returned values are presumably unaffected.
     * Confirm this before relying on the index.
     *
     * <p>NOTE(review): pmdarima documents that {@code m=1} means non-seasonal, so
     * {@code seasonal=True} has no effect here.
     */
    private static final String SARIMAX_SCRIPT =
            "import numpy as np\n"
            + "import pandas as pd\n"
            + "from statsmodels.tsa.statespace.sarimax import SARIMAX\n"
            + "from pmdarima.arima import auto_arima\n"
            + "\n"
            + "print(f\"inputX: {inputX}\")\n"
            + "print(f\"inputY: {inputY}\")\n"
            + "print(f\"outputX: {outputX}\")\n"
            + "print(f\"outputY: {outputY}\")\n"
            + "\n"
            + "inputSeries = pd.Series(inputY, index=pd.to_datetime(inputX))\n"
            + "\n"
            + "model = auto_arima(inputSeries, seasonal=True, m=1)\n"
            + "model_fit = SARIMAX(inputSeries, order=model.order, seasonal_order=model.seasonal_order).fit()\n"
            + "\n"
            + "outputSeries = model_fit.forecast(steps=len(outputX))\n"
            + "\n"
            + "outputY = outputSeries.values.tolist()\n"
            + "print(f\"result: {outputY}\")\n"
            + "return outputY\n";

    private final PythonScriptEngine pythonScriptEngine;

    @Autowired
    public SarimaxMethod(PythonScriptEngine pythonScriptEngine) {
        this.pythonScriptEngine = pythonScriptEngine;
    }

    /**
     * Returns the prediction type identifying this executor.
     *
     * @return {@link PredictionType#SARIMAX}
     */
    public PredictionType getPredictiontype() {
        return PredictionType.SARIMAX;
    }

    /**
     * Forecasts values at {@code predictionTimes} using a SARIMAX model fitted on {@code learnSet}.
     *
     * <p>No work is done until the returned {@link Mono} is subscribed to. The forecast values are
     * paired with {@code predictionTimes} by position (the script uses only their count). The
     * resulting points are then sorted by timestamp.
     *
     * @param user            security context passed through to the script engine
     * @param learnSet        historical series the model is fitted on
     * @param predictionTimes timestamps to produce forecast points for
     * @param methodData      method-specific settings; not used by this implementation
     * @return a {@link Mono} emitting one point per prediction time. It errors if the script fails,
     *         or if the script returns something other than a list of numbers matching
     *         {@code predictionTimes} in length.
     */
    public Mono<TimeSeries> predict(JwtSecurityUser user, TimeSeries learnSet, long[] predictionTimes, PredictionMethodData methodData) {
        // Defer so input marshalling (and any exception it throws) happens on subscription
        // and is surfaced as an onError signal.
        return Mono.defer(() -> {
            JavaType resultType = JsonUtils.getObjectMapper().getTypeFactory()
                    .constructCollectionType(List.class, Double.class);
            Map<String, Object> input = buildScriptInput(learnSet, predictionTimes);
            return this.pythonScriptEngine.runScript(user, SARIMAX_SCRIPT, resultType, input)
                    .map(executionResult -> toTimeSeries(executionResult.getResult(), predictionTimes));
        });
    }

    /**
     * Builds the variable map read by {@link #SARIMAX_SCRIPT}. It holds the learn-set timestamps
     * and values, the target timestamps, and an empty output list.
     */
    private static Map<String, Object> buildScriptInput(TimeSeries learnSet, long[] predictionTimes) {
        Map<String, Object> input = new HashMap<>();
        input.put("inputX", learnSet.getPoints().stream().map(Point::getTs).collect(Collectors.toList()));
        input.put("inputY", learnSet.getPoints().stream().map(Point::getValue).collect(Collectors.toList()));
        input.put("outputX", predictionTimes);
        // Placeholder only: the script prints it, then reassigns outputY to the forecast.
        input.put("outputY", new ArrayList<Double>());
        return input;
    }

    /**
     * Pairs the script's forecast values with {@code predictionTimes} by index and returns them
     * as a timestamp-sorted series.
     *
     * @throws IllegalStateException if the result is not a list, its size differs from
     *                               {@code predictionTimes.length}, or an element is not a number
     */
    private static TimeSeries toTimeSeries(Object scriptResult, long[] predictionTimes) {
        if (!(scriptResult instanceof List)) {
            throw new IllegalStateException("SARIMAX script did not return a list of predictions: " + scriptResult);
        }
        List<?> predictedValues = (List<?>) scriptResult;
        if (predictionTimes.length != predictedValues.size()) {
            throw new IllegalStateException("Prediction sets are not corresponding: expected "
                    + predictionTimes.length + " values, got " + predictedValues.size());
        }
        return IntStream.range(0, predictionTimes.length)
                .mapToObj(i -> new Point(predictionTimes[i], toDouble(predictedValues.get(i), i)))
                .sorted(Comparator.comparingLong(Point::getTs))
                .collect(Collectors.collectingAndThen(Collectors.toList(), TimeSeries::new));
    }

    /** Converts one forecast element to {@code double}, rejecting {@code null} and non-numeric values. */
    private static double toDouble(Object value, int index) {
        if (value instanceof Number) {
            return ((Number) value).doubleValue();
        }
        throw new IllegalStateException("SARIMAX prediction at index " + index + " is not a number: " + value);
    }
}

