/*
 * Decompiled with CFR 0.152.
 */
package org.thingsboard.trendz.service.model.prediction.methods;

import java.util.Locale;
import lombok.Generated;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.springframework.stereotype.Service;
import org.thingsboard.trendz.exception.TrendzInternalException;
import org.thingsboard.trendz.service.model.prediction.methods.ArimaMethodParameters;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethod;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethodParameters;
import org.thingsboard.trendz.service.model.prediction.methods.PredictionMethodType;

@Service
public class ArimaMethod
implements PredictionMethod {
    @Generated
    private static final Logger log = LoggerFactory.getLogger(ArimaMethod.class);
    private static final int MAX_POINTS = 2048;

    public PredictionMethodType getType() {
        return PredictionMethodType.ARIMA;
    }

    public String getMethodDefinition(PredictionMethodParameters methodParameters) {
        ArimaMethodParameters parameters = (ArimaMethodParameters)methodParameters;
        int p = parameters.getP();
        int d = parameters.getD();
        int q = parameters.getQ();
        if (p < 0 || d < 0 || q < 0) {
            throw new TrendzInternalException("Forbidden values!");
        }
        return "#####################################################\n# Prediction Method: ARIMA\n\nfrom sklearn.preprocessing import MinMaxScaler, StandardScaler\nfrom statsmodels.tsa.arima.model import ARIMA\nimport numpy as np\nimport pickle\n\n\nclass CustomModel(IModel):\n    def __init__(self, value_transformer=None, timestamp_transformer=None, order=(%s, %s, %s)):\n        self.model = None\n        self.order = order\n        self.timestamp_transformer = timestamp_transformer if timestamp_transformer else StandardScaler()\n        self.value_transformer = value_transformer if value_transformer else MinMaxScaler()\n        self.timestamps = np.array([])\n        self.values = np.array([])\n\n    def init_state(self):\n        self.timestamps = np.array([])\n        self.values = np.array([])\n\n    def train(self, data, additionalData=None):\n        # Prepare\n        self.values = np.array([point[1] for point in data]).reshape(-1, 1)\n        self.timestamps = np.array([point[0] for point in data]).reshape(-1, 1)\n        self.value_transformer.fit(self.values)\n        self.timestamp_transformer.fit(self.timestamps)\n        values_scaled = self.value_transformer.transform(self.values).flatten()\n        timestamps_scaled = self.timestamp_transformer.transform(self.timestamps).flatten()\n        print(\"ARIMA: prepared data\")\n\n        # Fit\n        print(values_scaled)\n        print(self.order)\n        print(timestamps_scaled)\n        self.model = ARIMA(values_scaled, order=self.order, exog=timestamps_scaled).fit()\n        print(\"ARIMA: fitted\")\n\n    def partial_fit(self, data, additionalData=None):\n        # Prepare\n        new_values = np.array([point[1] for point in data]).reshape(-1, 1)\n        new_timestamps = np.array([point[0] for point in data]).reshape(-1, 1)\n        self.values = np.concatenate((self.values, new_values), axis=0) if self.values.size else new_values\n        self.timestamps = np.concatenate((self.timestamps, 
new_timestamps), axis=0) if self.timestamps.size else new_timestamps\n\n        if self.values.shape[0] > %d:\n            self.values = self.values[-%d:]\n            self.timestamps = self.timestamps[-%d:]\n\n        self.value_transformer.fit(self.values)\n        self.timestamp_transformer.fit(self.timestamps)\n        values_scaled = self.value_transformer.transform(self.values).flatten()\n        timestamps_scaled = self.timestamp_transformer.transform(self.timestamps).flatten()\n\n        # Fit\n        self.model = ARIMA(values_scaled, order=self.order, exog=timestamps_scaled).fit()\n\n    def predict(self, timestamps):\n        ts = np.array(timestamps).reshape(-1, 1)\n        ts_scaled = self.timestamp_transformer.transform(ts).flatten()\n        n = len(timestamps)\n        forecast = self.model.forecast(steps=n, exog=ts_scaled)\n        forecast = self.value_transformer.inverse_transform(forecast.reshape(-1, 1)).flatten()\n\n        predictions = list(zip(timestamps, forecast))\n        return predictions\n\n    def save_state(self, file_path):\n        with open(file_path, 'wb') as f:\n            pickle.dump({\n                'model': self.model,\n                'timestamps': self.timestamps,\n                'values': self.values,\n                'value_transformer': self.value_transformer,\n                'timestamp_transformer': self.timestamp_transformer\n            }, f)\n\n    def load_state(self, file_path):\n        with open(file_path, 'rb') as f:\n            state = pickle.load(f)\n            self.model = state['model']\n            self.values = state['values']\n            self.timestamps = state['timestamps']\n            self.value_transformer = state['value_transformer']\n            self.timestamp_transformer = state['timestamp_transformer']\n\n    def name(self):\n        return \"ARIMAModel\"\n\n#####################################################\n".formatted(p, d, q, 2048, 2048, 2048);
    }
}

