package eu.mopso.tc.commands;

import com.fasterxml.jackson.databind.ObjectMapper;
import com.opencsv.CSVReader;
import com.opencsv.exceptions.CsvValidationException;
import eu.mopso.tc.Env;
import eu.mopso.tc.FileUtil;
import eu.mopso.tc.HttpUtil;
import eu.mopso.tc.VersionProvider;
import eu.mopso.tc.models.BuildModelRequestBody;
import eu.mopso.tc.models.Prediction;
import java.io.BufferedReader;
import java.io.IOException;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import org.apache.logging.log4j.message.StructuredDataId;
import org.springframework.beans.propertyeditors.CustomBooleanEditor;
import picocli.CommandLine;

@CommandLine.Command(name = "train", mixinStandardHelpOptions = true, description = {"Train the model"}, versionProvider = VersionProvider.class)
public class ModelTrainCommand implements Callable<Integer> {

    @CommandLine.Option(names = {"-k", "--api-key"}, description = {"A registered api key. If not present the value from the env variable TC_API_KEY is used"})
    private String apiKey;

    @CommandLine.Option(names = {"-e", "--endpoint"}, description = {"Api endpoint."})
    private String apiEndpoint;

    @CommandLine.Option(names = {"-n", "--name"}, description = {"Name of the model to train, create a new model if name is not found."}, required = true)
    private String name;

    // "-" is the documented stdin marker; make it the actual default instead of leaving the field null.
    @CommandLine.Option(names = {"-i", "--input"}, description = {"The file containing the training data, by default “-” that means std in (e.g. -i - ). The stream is supposed to be in CSV format and MUST contain two fields (“prototype” and “class”) plus an optional field “weight” ranging from 0 to 1 (1 by default). Maximum allowed length for classes is 2048 characters."}, defaultValue = "-")
    private String inputFilename;

    @CommandLine.Option(names = {"-H", "--header"}, description = {"Will ignore the first line of the csv file."})
    private boolean skipHeader;

    @CommandLine.Option(names = {"-P", "--prototype-index"}, description = {"The column index in the csv file that contains the field with the text to classify (from 0). Default to 0"}, defaultValue = "0")
    private int textIndex;

    @CommandLine.Option(names = {"-C", "--class-index"}, description = {"The column index in the csv file that contains the field with the class attached to the text (from 0). Default to 1."}, defaultValue = "1")
    private int labelIndex;

    // Any negative value (default -1) means "no weight column": every row gets DEFAULT_WEIGHT.
    @CommandLine.Option(names = {"-W", "--weight-index"}, description = {"The column index in the csv file that contains the field with the classification weight (from 0). Set to -1 if not present"}, defaultValue = "-1")
    private int weightIndex;

    private final ObjectMapper jsonMapper = new ObjectMapper();

    /** Maximum number of training rows sent to the API in a single request. */
    private static final int BATCH_SIZE = 16;

    /** Weight used when no weight column is configured or the weight cell is blank. */
    private static final float DEFAULT_WEIGHT = 1.0f;

    /**
     * Reads the CSV training data and uploads it to the API in batches of at most
     * {@value #BATCH_SIZE} rows.
     *
     * <p>Rows whose text or class cell is blank (or missing) are skipped with a warning on stderr.
     * Line numbers in log messages are 1-based CSV record numbers, counting the header row when
     * {@code --header} is given; they can differ from physical line numbers when quoted fields
     * contain embedded newlines.
     *
     * @return {@code 0} on success, {@code 1} on any failure (the reason is printed to stderr)
     */
    @Override
    public Integer call() throws Exception {
        try {
            System.err.println("Starting training");
            validateColumnIndexes();
            HttpUtil httpUtil = new HttpUtil(resolveApiKey(), resolveApiEndpoint());
            // try-with-resources so the input is closed even when a request or a row fails.
            try (BufferedReader fileReader = FileUtil.getFileReader(this.inputFilename);
                 CSVReader csvReader = new CSVReader(fileReader)) {
                int recordNumber = 0;
                if (this.skipHeader && hasNext(csvReader)) {
                    csvReader.skip(1);
                    recordNumber++;
                }
                while (hasNext(csvReader)) {
                    int firstRecord = recordNumber + 1;
                    List<Prediction> batch = new ArrayList<>(BATCH_SIZE);
                    while (batch.size() < BATCH_SIZE && hasNext(csvReader)) {
                        recordNumber++;
                        Prediction prediction = toPrediction(nextValue(csvReader), recordNumber);
                        if (prediction != null) {
                            batch.add(prediction);
                        }
                    }
                    // A trailing run of blank rows yields an empty batch; don't send an empty request.
                    if (!batch.isEmpty()) {
                        System.err.printf("Processing lines %d..%d.\n", firstRecord, recordNumber);
                        makeTrainRequest(batch, httpUtil);
                    }
                }
            }
            System.err.printf("Model training completed. Model name: %s\n", this.name);
            return 0;
        } catch (Exception e) {
            // Some exceptions (e.g. NullPointerException) carry no message; fall back to toString().
            String reason = e.getMessage() != null ? e.getMessage() : e.toString();
            System.err.printf("Something went wrong during training: %s\n", reason);
            return 1;
        }
    }

    /** Rejects negative indexes for the mandatory text and class columns. */
    private void validateColumnIndexes() {
        if (this.textIndex < 0 || this.labelIndex < 0) {
            throw new IllegalArgumentException("CSV column indexes cannot be negative");
        }
    }

    /**
     * Returns the API key, falling back to {@code Env.apiKey()} when {@code --api-key} was not
     * given. The resolved value is also stored back into {@link #apiKey}.
     *
     * @throws IllegalArgumentException if no key is available or the key is blank
     */
    private String resolveApiKey() {
        if (this.apiKey == null) {
            this.apiKey = Env.apiKey();
        }
        if (this.apiKey == null) {
            throw new IllegalArgumentException("Missing required api-key, set TC_API_KEY or use --api-key option");
        }
        if (this.apiKey.isBlank()) {
            throw new IllegalArgumentException("Malformed api-key");
        }
        return this.apiKey;
    }

    /**
     * Returns the API endpoint, falling back to {@code Env.apiEndpoint()} when {@code --endpoint}
     * was not given. The resolved value is also stored back into {@link #apiEndpoint}.
     */
    private String resolveApiEndpoint() {
        if (this.apiEndpoint == null) {
            this.apiEndpoint = Env.apiEndpoint();
        }
        return this.apiEndpoint;
    }

    /**
     * Converts one CSV row into a training sample.
     *
     * @param row          the CSV cells of the row
     * @param recordNumber 1-based record number, used only in messages
     * @return the sample, or {@code null} if the text or class cell is blank or missing
     * @throws IllegalArgumentException if the weight cell is present but not a number
     */
    private Prediction toPrediction(List<String> row, int recordNumber) {
        String text = cellAt(row, this.textIndex);
        String label = cellAt(row, this.labelIndex);
        if (text.isBlank() || label.isBlank()) {
            System.err.printf("Line %d is empty, skipped\n", recordNumber);
            return null;
        }
        Prediction prediction = new Prediction();
        prediction.setText(text);
        prediction.setLabel(label);
        prediction.setWeight(parseWeight(row, recordNumber));
        return prediction;
    }

    /**
     * Reads the weight of a row. Returns {@link #DEFAULT_WEIGHT} when no weight column is
     * configured (negative index) or the cell is blank or missing.
     */
    private float parseWeight(List<String> row, int recordNumber) {
        if (this.weightIndex < 0) {
            return DEFAULT_WEIGHT;
        }
        String raw = cellAt(row, this.weightIndex);
        if (raw.isBlank()) {
            return DEFAULT_WEIGHT;
        }
        try {
            return Float.parseFloat(raw);
        } catch (NumberFormatException e) {
            throw new IllegalArgumentException(
                    String.format("Line %d: invalid weight \"%s\"", recordNumber, raw), e);
        }
    }

    /** Returns the cell at {@code index}, or an empty string if the row is shorter than that. */
    private static String cellAt(List<String> row, int index) {
        return index < row.size() ? row.get(index) : "";
    }

    /** Returns {@code true} if at least one more CSV record can be read, without consuming it. */
    private boolean hasNext(CSVReader cSVReader) throws IOException {
        return cSVReader.peek() != null;
    }

    /** Reads the next CSV record; callers must check {@link #hasNext(CSVReader)} first. */
    private List<String> nextValue(CSVReader cSVReader) throws IOException, CsvValidationException {
        return new ArrayList<>(Arrays.asList(cSVReader.readNext()));
    }

    /**
     * Sends one batch of samples to the training endpoint for model {@link #name}.
     *
     * @throws Exception if serialization fails, the HTTP call fails, or the server answers with a
     *                   status code of 400 or higher (the message includes the status and body)
     */
    private void makeTrainRequest(List<Prediction> list, HttpUtil httpUtil) throws Exception {
        HttpResponse<String> response = httpUtil.trainModel(
                this.jsonMapper.writeValueAsString(new BuildModelRequestBody(this.name, list)));
        if (response.statusCode() >= 400) {
            throw new Exception("HTTP " + response.statusCode() + ": " + response.body());
        }
    }
}
