/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.trendz.service.model.prediction.handler;

import com.fasterxml.jackson.databind.JavaType;
import com.fasterxml.jackson.databind.type.CollectionType;
import java.lang.invoke.CallSite;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.UUID;
import java.util.stream.Collectors;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.stereotype.Service;
import org.thingsboard.trendz.domain.base.Point;
import org.thingsboard.trendz.domain.base.TimeSeries;
import org.thingsboard.trendz.security.entity.JwtSecurityUser;
import org.thingsboard.trendz.service.model.prediction.PredictionModelDefinitionService;
import org.thingsboard.trendz.service.model.prediction.handler.PredictionModelHandlerInterface;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethod;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethodParameters;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethodType;
import org.thingsboard.trendz.service.script.engine.PythonScriptEngine;
import org.thingsboard.trendz.service.script.engine.ScriptExecutionResult;
import org.thingsboard.trendz.tools.json.JsonUtils;
import reactor.core.publisher.Mono;

/**
 * Prediction-model handler that runs every model lifecycle step (state creation, training,
 * partial training, prediction) as a Python script through {@link PythonScriptEngine}.
 *
 * <p>Model state is exchanged with Python as a base64-encoded string ({@code internal_state} on
 * input, the script's return value on output). Each script materializes that state as a file named
 * {@code model_state_<executionId>} under {@code /tmp/python_executor_temp_files} so the Python model
 * can {@code load_state}/{@code save_state} it, and removes the file afterwards. The {@code CustomModel}
 * class used by the scripts is expected to come from the selected {@link PredictionMethod}'s
 * interface/definition code (NOTE(review): confirm against PredictionMethod implementations).
 */
@Service
public class PythonPredictionModelHandler implements PredictionModelHandlerInterface {
    private static final Logger log = LoggerFactory.getLogger(PythonPredictionModelHandler.class);

    /** Directory on the Python side where serialized model state is temporarily written. */
    private static final String PYTHON_WORKING_DIRECTORY_PATH = "/tmp/python_executor_temp_files";

    /**
     * Prologue appended to every script: imports, ensures the working directory exists and defines
     * {@code file_path}, a state file unique to one execution.
     * Placeholders (in order): working directory, execution id.
     * The leading blank line separates it from the method definition code it follows.
     */
    private static final String STATE_FILE_PROLOGUE = """

            import os
            import base64

            working_directory = "%s"
            if not os.path.exists(working_directory):
                os.makedirs(working_directory)

            file_path = working_directory + "/model_state_%s"

            """;

    /** Creates a fresh model, saves its initial state and returns that state base64-encoded. */
    private static final String CREATE_STATE_SCRIPT = """
            from sklearn.preprocessing import MinMaxScaler
            from sklearn.preprocessing import StandardScaler
            value_transformer = MinMaxScaler()
            timestamp_transformer = StandardScaler()

            model = CustomModel(value_transformer, timestamp_transformer)
            model.init_state()
            model.save_state(file_path)

            with open(file_path, 'rb') as file:
                file_content = file.read()
                file_encoded_content = base64.encodebytes(file_content).decode('utf-8')
            os.remove(file_path)

            return file_encoded_content
            """;

    /** Restores the model from {@code internal_state} and returns {@code model.predict(outputX)}. */
    private static final String PREDICT_SCRIPT = """
            internal_state_bytes = base64.decodebytes(internal_state.encode('utf-8'))
            with open(file_path, 'wb') as file:
                file.write(internal_state_bytes)

            model = CustomModel()
            model.load_state(file_path)
            os.remove(file_path)

            outputY = model.predict(outputX)
            return outputY
            """;

    /**
     * Restores the model from {@code internal_state}, calls the training method named by the single
     * placeholder with {@code data} and {@code additionalData}, and returns the updated state
     * base64-encoded. The state file is read inside {@code with} so the handle is closed before removal.
     */
    private static final String FIT_SCRIPT_TEMPLATE = """
            import pandas as pd

            internal_state_bytes = base64.decodebytes(internal_state.encode('utf-8'))
            with open(file_path, 'wb') as file:
                file.write(internal_state_bytes)
            print("Model data is written")

            model = CustomModel()
            print("Model is created")
            model.load_state(file_path)
            os.remove(file_path)
            print("Model is loaded from data")

            model.%s(data, additionalData)
            print("Model action is performed")

            model.save_state(file_path)
            print("Model is saved to data")

            with open(file_path, 'rb') as file:
                file_content = file.read()
                file_encoded_content = base64.encodebytes(file_content).decode('utf-8')
            os.remove(file_path)
            print("Model data is loaded")

            return file_encoded_content
            """;

    private final PredictionModelDefinitionService predictionModelDefinitionService;
    private final PythonScriptEngine pythonScriptEngine;

    @Autowired
    public PythonPredictionModelHandler(PredictionModelDefinitionService predictionModelDefinitionService, PythonScriptEngine pythonScriptEngine) {
        this.predictionModelDefinitionService = predictionModelDefinitionService;
        this.pythonScriptEngine = pythonScriptEngine;
    }

    /**
     * Creates the initial, untrained state of the model selected by {@code parameters}.
     *
     * @param user       user on whose behalf the script runs
     * @param parameters prediction method type and its parameters
     * @return the base64-encoded model state produced by the script
     */
    public Mono<String> createModelState(JwtSecurityUser user, PredictionMethodParameters parameters) {
        String finalScript = this.composeScript(parameters, CREATE_STATE_SCRIPT);
        return this.runStateScript(user, finalScript, Collections.emptyMap());
    }

    /**
     * Trains the model from scratch (Python {@code train}) on the given telemetry.
     *
     * @param state                  base64-encoded model state to start from
     * @param inputTelemetry         target series; NOTE: sorted in place via {@link TimeSeries#sort()}
     * @param additionalTelemetryMap optional extra series, may be {@code null} or empty
     * @return the updated base64-encoded model state
     */
    public Mono<String> fitModel(JwtSecurityUser user, PredictionMethodParameters parameters, String state, TimeSeries inputTelemetry, Map<UUID, TimeSeries> additionalTelemetryMap) {
        return this.fit(user, parameters, state, inputTelemetry, additionalTelemetryMap, false);
    }

    /**
     * Incrementally trains the model (Python {@code partial_fit}) on the given telemetry.
     *
     * @param state                  base64-encoded model state to start from
     * @param inputTelemetry         target series; NOTE: sorted in place via {@link TimeSeries#sort()}
     * @param additionalTelemetryMap optional extra series, may be {@code null} or empty
     * @return the updated base64-encoded model state
     */
    public Mono<String> partialFitModel(JwtSecurityUser user, PredictionMethodParameters parameters, String state, TimeSeries inputTelemetry, Map<UUID, TimeSeries> additionalTelemetryMap) {
        return this.fit(user, parameters, state, inputTelemetry, additionalTelemetryMap, true);
    }

    /**
     * Runs the model's {@code predict} for the requested timestamps.
     *
     * @param state   base64-encoded model state
     * @param outputX timestamps to predict values for
     * @return predicted points sorted by timestamp; errors with a {@link RuntimeException} when the
     *         script does not return exactly one point per requested timestamp
     */
    public Mono<TimeSeries> predict(JwtSecurityUser user, PredictionMethodParameters parameters, String state, List<Long> outputX) {
        String finalScript = this.composeScript(parameters, PREDICT_SCRIPT);
        Map<String, Object> input = new HashMap<>();
        input.put("outputX", outputX);
        input.put("internal_state", state);
        // The script result is parsed as List<List<Double>>: one [timestamp, value] pair per point.
        JavaType pairType = JsonUtils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, Double.class);
        JavaType parsingType = JsonUtils.getObjectMapper().getTypeFactory().constructCollectionType(List.class, pairType);
        return this.pythonScriptEngine.runScript(user, finalScript, parsingType, input)
                .map(ScriptExecutionResult::getResult)
                .flatMap(result -> toPredictedSeries(result, outputX));
    }

    /**
     * Shared implementation of {@link #fitModel} and {@link #partialFitModel}.
     *
     * @param partial {@code true} to call Python {@code partial_fit}, {@code false} for {@code train}
     */
    private Mono<String> fit(JwtSecurityUser user, PredictionMethodParameters parameters, String state, TimeSeries inputTelemetry, Map<UUID, TimeSeries> additionalTelemetryMap, boolean partial) {
        String finalScript = this.composeScript(parameters, FIT_SCRIPT_TEMPLATE.formatted(partial ? "partial_fit" : "train"));

        // Sorts the caller's series in place (pre-existing behavior).
        inputTelemetry.sort();
        List<List<Number>> data = inputTelemetry.getPoints().stream()
                .map(PythonPredictionModelHandler::toTimestampValuePair)
                .collect(Collectors.toList());

        // Each extra series becomes a Python regressor named "telemetry_<uuid with '_' instead of '-'>".
        Map<String, List<List<Number>>> additionalData = new HashMap<>();
        if (additionalTelemetryMap != null) {
            additionalTelemetryMap.forEach((telemetryKey, additionalTelemetry) -> {
                String regressorName = "telemetry_" + telemetryKey.toString().replace('-', '_');
                List<List<Number>> values = additionalTelemetry.getPoints().stream()
                        .sorted(Comparator.comparingLong(Point::getTs))
                        .map(PythonPredictionModelHandler::toTimestampValuePair)
                        .collect(Collectors.toList());
                additionalData.put(regressorName, values);
            });
        }

        Map<String, Object> input = new HashMap<>();
        input.put("data", data);
        input.put("additionalData", additionalData);
        input.put("internal_state", state);
        return this.runStateScript(user, finalScript, input);
    }

    /**
     * Assembles a complete script: the method's interface code, its parameterized definition, the
     * shared state-file prologue (bound to a fresh execution id) and the operation-specific body.
     * Only the prologue is passed through {@code formatted}; method code is appended verbatim.
     */
    private String composeScript(PredictionMethodParameters parameters, String scriptBody) {
        String executionId = this.makeExecutionId();
        PredictionMethodType type = parameters.getType();
        PredictionMethod predictionMethod = this.predictionModelDefinitionService.getPredictionMethodByType(type);
        String methodScript = predictionMethod.getMethodDefinition(parameters);
        String prologue = STATE_FILE_PROLOGUE.formatted(PYTHON_WORKING_DIRECTORY_PATH, executionId);
        return predictionMethod.getMethodInterface() + methodScript + prologue + scriptBody;
    }

    /** Runs a script whose result is parsed as a String and returns that string (the encoded state). */
    private Mono<String> runStateScript(JwtSecurityUser user, String script, Map<String, Object> input) {
        JavaType returnType = JsonUtils.getObjectMapper().getTypeFactory().constructType(String.class);
        return this.pythonScriptEngine.runScript(user, script, returnType, input)
                .map(scriptExecutionResult -> scriptExecutionResult.getResult().toString());
    }

    /**
     * Converts the parsed prediction output into a timestamp-sorted series.
     *
     * @param result  script result, expected to be a list of [timestamp, value] numeric pairs
     * @param outputX requested timestamps, used only to validate the number of returned points
     */
    @SuppressWarnings("unchecked")
    private static Mono<TimeSeries> toPredictedSeries(Object result, List<Long> outputX) {
        List<List<Number>> output = (List<List<Number>>) result;
        if (output == null || outputX.size() != output.size()) {
            return Mono.error(new RuntimeException("Prediction sets are not corresponding"));
        }
        TimeSeries series = output.stream()
                .map(pair -> new Point(pair.get(0).longValue(), pair.get(1).doubleValue()))
                .sorted(Comparator.comparingLong(Point::getTs))
                .collect(Collectors.collectingAndThen(Collectors.toList(), TimeSeries::new));
        return Mono.just(series);
    }

    /** Encodes a point as the two-element [timestamp, value] list the Python side consumes. */
    private static List<Number> toTimestampValuePair(Point point) {
        return List.of(Long.valueOf(point.getTs()), Double.valueOf(point.getValue()));
    }

    /** Returns a random UUID with '-' replaced by '_' (used in the temporary state file name). */
    private String makeExecutionId() {
        return UUID.randomUUID().toString().replace('-', '_');
    }
}

