package eu.mopso.tc.commands;

import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.opencsv.CSVReader;
import com.opencsv.CSVWriter;
import com.opencsv.exceptions.CsvValidationException;
import eu.mopso.tc.Env;
import eu.mopso.tc.FileUtil;
import eu.mopso.tc.HttpUtil;
import eu.mopso.tc.VersionProvider;
import eu.mopso.tc.models.ClassifyRequestBody;
import eu.mopso.tc.models.Prediction;
import java.io.BufferedReader;
import java.io.BufferedWriter;
import java.io.IOException;
import java.net.http.HttpResponse;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.CompletionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Semaphore;
import java.util.concurrent.atomic.AtomicBoolean;
import org.springframework.beans.propertyeditors.CustomBooleanEditor;
import picocli.CommandLine;

/**
 * Picocli {@code classify} command: reads CSV rows, sends the text found in column
 * {@code --index} to the classification service in batches, and writes every row back with
 * two extra columns, {@code CLASS} and {@code SIMILARITY_SCORE}.
 *
 * <p>Batches run on a fixed pool of {@code --threads} workers. A semaphore with the same
 * number of permits bounds how many batches are in flight, so input is read lazily instead
 * of being queued in full. Writes to the shared output are serialized on {@link #lock}.
 */
@CommandLine.Command(name = "classify", mixinStandardHelpOptions = true, description = {"Perform classification"}, versionProvider = VersionProvider.class)
public class ClassifyCommand implements Callable<Integer> {
    private static final String ERROR_LABEL = "ERROR";
    private static final String OTHER_LABEL = "OTHER";
    private static final String SIMILARITY_SCORE_HEADER = "SIMILARITY_SCORE";
    private static final String CLASS_HEADER = "CLASS";
    private static final String L2_DISTANCE_ALGORITHM = "L2_DISTANCE";
    private static final String COSINE_SIMILARITY_ALGORITHM = "COSINE_SIMILARITY";
    /** Rows sent per request in buffered mode. */
    private static final int DEFAULT_BATCH_SIZE = 16;
    /** Jackson target type for the service response body. */
    private static final TypeReference<List<Prediction>> PREDICTION_LIST_TYPE = new TypeReference<List<Prediction>>() {
    };

    /** Rows per request; forced to 1 in {@code --no-buffer} mode. */
    private int batchSize = DEFAULT_BATCH_SIZE;
    private final ObjectMapper jsonMapper = new ObjectMapper();
    /** Set by a failing batch in strict mode so that no further batches are queued. */
    private final AtomicBoolean cancellationToken = new AtomicBoolean(false);
    /** Serializes writes and flushes on the shared CSV writer. */
    private final Object lock = new Object();

    @CommandLine.Option(names = {"-k", "--api-key"}, description = {"A registered api key. If not present the value from the env variable TC_API_KEY is used"})
    private String apiKey;

    @CommandLine.Option(names = {"-e", "--endpoint"}, description = {"Api endpoint."})
    private String apiEndpoint;

    @CommandLine.Option(names = {"-n", "--name"}, description = {"Model name."}, required = true)
    private String name;

    @CommandLine.Option(names = {"-I", "--index"}, description = {"The index in the csv file that contains the field to classify (from 0). By default, the index is 0"}, defaultValue = "0")
    private int textIndex;

    @CommandLine.Option(names = {"-i", "--input"}, description = {"The input filename; “-” means stdin (e.g. -i - ). The file must be in CSV format"})
    private String inFilename;

    @CommandLine.Option(names = {"-o", "--output"}, description = {"The output filename; “-” means stdout (e.g. -o - )."})
    private String outFilename;

    @CommandLine.Option(names = {"-a", "--alg"}, description = {"Similarity algorithm to use. Possible values are 'L2_DISTANCE' and 'COSINE_SIMILARITY'. Default 'COSINE_SIMILARITY'"}, defaultValue = COSINE_SIMILARITY_ALGORITHM, hidden = true)
    private String algorithm;

    @CommandLine.Option(names = {"-S", "--strict"}, description = {"Runs the program in strict mode: any partially recoverable exception thrown during the execution (i.e. a classification that fails or a row that can't be parsed) will stop the program truncating the output to the last stable state. If not run in strict mode, the application will try to compensate for as many errors as it's possible."})
    private boolean isStrict;

    @CommandLine.Option(names = {"-H", "--header"}, description = {"If the flag is present the first line of the input file is copied to the output with added columns 'CLASS' and 'SIMILARITY_SCORE' . By default, it is assumed that the input has no header"})
    private boolean copyHeader;

    @CommandLine.Option(names = {"-t", "--threshold"}, description = {"Number between 0 (not included) and 1 (included) that is used to determine whether a match ‘SIMILARITY_SCORE’ is too low to be considered valid. In this case, the ‘CLASS’ is set to ‘OTHER’. Default value is 0.84"}, defaultValue = "0.84")
    private float threshold;

    @CommandLine.Option(names = {"-T", "--threads"}, description = {"The number of parallel jobs to be used by the classification services, by default is 1. If more than 1 is used the output order is not preserved. The value is capped to the number of cpu cores."}, defaultValue = "1")
    private int threads;

    @CommandLine.Option(names = {"--no-buffer"}, description = {"Execute the program in interactive mode. Will ignore --input, --output and --header options."})
    private boolean noBuffer;

    /**
     * Validates the options, runs the classification and reports the outcome on stderr.
     *
     * @return {@code 0} when every queued batch was classified and written, {@code 1} otherwise
     */
    @Override
    public Integer call() throws Exception {
        try {
            validateAndResolveOptions();
            classifyAll();
            return 0;
        } catch (InterruptedException e) {
            // Preserve the interrupt for whoever owns this thread.
            Thread.currentThread().interrupt();
            System.err.printf("Something went wrong during classification: %s\n", e.getMessage());
            return 1;
        } catch (Exception e) {
            System.err.printf("Something went wrong during classification: %s\n", rootMessage(e));
            return 1;
        }
    }

    /** Checks option values and fills in defaults taken from the environment. */
    private void validateAndResolveOptions() {
        if (this.textIndex < 0) {
            throw new IllegalArgumentException("CSV column indexes cannot be negative");
        }
        if (!L2_DISTANCE_ALGORITHM.equals(this.algorithm) && !COSINE_SIMILARITY_ALGORITHM.equals(this.algorithm)) {
            throw new IllegalArgumentException("Invalid algorithm provided supported values: %s, %s".formatted(L2_DISTANCE_ALGORITHM, COSINE_SIMILARITY_ALGORITHM));
        }
        // A pool of 0 threads is rejected by Executors and a 0-permit semaphore would block forever.
        if (this.threads < 1) {
            throw new IllegalArgumentException("Threads number cannot be 0 or negative");
        }
        if (this.apiKey == null) {
            this.apiKey = Env.apiKey();
        }
        if (this.apiKey == null) {
            throw new IllegalArgumentException("Missing required api-key, set TC_API_KEY or use --api-key option");
        }
        if (this.apiKey.isBlank()) {
            throw new IllegalArgumentException("Malformed api-key");
        }
        if (this.apiEndpoint == null) {
            this.apiEndpoint = Env.apiEndpoint();
        }
        int cores = Runtime.getRuntime().availableProcessors();
        if (this.threads > cores) {
            throw new IllegalArgumentException(String.format("Too many threads count, for better performance please set a number between 1 and %s", Integer.valueOf(cores)));
        }
        if (this.noBuffer) {
            this.batchSize = 1;
            this.inFilename = "-";
            this.outFilename = "-";
            this.isStrict = true;
            this.copyHeader = false;
        }
    }

    /**
     * Opens input and output, copies the optional header, classifies every row and waits for
     * all queued batches before any stream is closed or the pool is shut down.
     */
    private void classifyAll() throws Exception {
        HttpUtil httpUtil = new HttpUtil(this.apiKey, this.apiEndpoint);
        ExecutorService executor = Executors.newFixedThreadPool(this.threads);
        long startMillis;
        try (BufferedReader fileReader = FileUtil.getFileReader(this.inFilename);
             CSVReader csvReader = new CSVReader(fileReader);
             BufferedWriter fileWriter = FileUtil.getFileWriter(this.outFilename);
             CSVWriter csvWriter = new CSVWriter(fileWriter)) {
            if (this.copyHeader && hasNext(csvReader)) {
                List<String> header = nextValue(csvReader);
                header.add(CLASS_HEADER);
                header.add(SIMILARITY_SCORE_HEADER);
                csvWriter.writeNext(header.toArray(new String[0]), false);
            }
            System.err.println("Starting classification");
            startMillis = System.currentTimeMillis();
            List<CompletableFuture<Void>> inFlight = new ArrayList<>();
            Exception queueFailure = null;
            try {
                queueBatches(csvReader, csvWriter, httpUtil, executor, inFlight);
            } catch (Exception e) {
                queueFailure = e;
            }
            // Every queued batch must finish before the writer is closed, even when queueing failed.
            Exception batchFailure = awaitAll(inFlight);
            if (queueFailure != null) {
                throw queueFailure;
            }
            if (batchFailure != null) {
                throw batchFailure;
            }
        } finally {
            // Worker threads are non-daemon; without shutdown they would keep the JVM alive.
            executor.shutdown();
        }
        System.err.printf("Job terminated in %d ms\n", Long.valueOf(System.currentTimeMillis() - startMillis));
    }

    /** Reads the input in batches and submits each batch, blocking while all workers are busy. */
    private void queueBatches(CSVReader csvReader, CSVWriter csvWriter, HttpUtil httpUtil,
                              ExecutorService executor, List<CompletableFuture<Void>> inFlight) throws Exception {
        Semaphore permits = new Semaphore(this.threads);
        int nextRow = 0;
        while (hasNext(csvReader)) {
            List<List<String>> rows = new ArrayList<>(this.batchSize);
            List<Prediction> requests = new ArrayList<>(this.batchSize);
            while (rows.size() < this.batchSize && hasNext(csvReader)) {
                List<String> row = nextValue(csvReader);
                if (this.isStrict && this.textIndex >= row.size()) {
                    throw new IllegalArgumentException("Input row %d has no column at index %d".formatted(nextRow + rows.size(), this.textIndex));
                }
                Prediction request = new Prediction();
                request.setText(textOf(row));
                requests.add(request);
                rows.add(row);
            }
            permits.acquire();
            if (this.cancellationToken.get()) {
                System.err.println("Stop queueing new task");
                break;
            }
            int firstRow = nextRow;
            inFlight.add(CompletableFuture.runAsync(
                    () -> runBatch(firstRow, rows, requests, httpUtil, csvWriter, permits), executor));
            nextRow += rows.size();
        }
    }

    /** Worker body: classifies and writes one batch, always returning its permit. */
    private void runBatch(int firstRow, List<List<String>> rows, List<Prediction> requests,
                          HttpUtil httpUtil, CSVWriter csvWriter, Semaphore permits) {
        try {
            System.err.printf("Processing lines %d..%d.\n", Integer.valueOf(firstRow), Integer.valueOf(firstRow + rows.size()));
            classifyBatch(rows, requests, httpUtil);
            writeAll(csvWriter, rows);
        } catch (Exception e) {
            // Raise the flag before the permit is released (finally) so the queueing loop
            // cannot acquire the permit and slip in another batch first.
            if (this.isStrict) {
                this.cancellationToken.set(true);
            }
            throw new CompletionException(e);
        } finally {
            permits.release();
        }
    }

    /** Appends CLASS / SIMILARITY_SCORE values to every row of the batch. */
    private void classifyBatch(List<List<String>> rows, List<Prediction> requests, HttpUtil httpUtil) throws Exception {
        List<Prediction> nonBlank = requests.stream().filter(p -> !p.getText().isBlank()).toList();
        if (nonBlank.isEmpty()) {
            // Nothing to classify: skip the round trip, every row is labelled as an error.
            handleErrorResult(rows);
            return;
        }
        HttpResponse<String> response = makeClassifyRequest(nonBlank, httpUtil);
        if (response == null) {
            handleErrorResult(rows);
        } else {
            handleSuccessfulResult(this.jsonMapper.readValue(response.body(), PREDICTION_LIST_TYPE), rows);
        }
    }

    /** Writes the rows and flushes atomically with respect to other workers. */
    private void writeAll(CSVWriter cSVWriter, List<List<String>> list) throws IOException {
        List<String[]> lines = list.stream().map(row -> row.toArray(new String[0])).toList();
        synchronized (this.lock) {
            cSVWriter.writeAll(lines, false);
            cSVWriter.flush();
        }
    }

    /** Waits for every task and returns the first failure, or {@code null} if all succeeded. */
    private static Exception awaitAll(List<CompletableFuture<Void>> tasks) {
        Exception first = null;
        for (CompletableFuture<Void> task : tasks) {
            try {
                task.join();
            } catch (RuntimeException e) {
                if (first == null) {
                    first = e;
                }
            }
        }
        return first;
    }

    /** Message of the innermost cause, stripping {@link CompletionException} wrappers. */
    private static String rootMessage(Exception e) {
        Throwable cause = e;
        while (cause instanceof CompletionException && cause.getCause() != null) {
            cause = cause.getCause();
        }
        return cause.getMessage();
    }

    private boolean hasNext(CSVReader cSVReader) throws IOException {
        return cSVReader.peek() != null;
    }

    /** Reads the next record as a mutable list so result columns can be appended. */
    private List<String> nextValue(CSVReader cSVReader) throws IOException, CsvValidationException {
        return new ArrayList<>(Arrays.asList(cSVReader.readNext()));
    }

    /** Text to classify for a row; rows lacking the column yield "" and are labelled ERROR. */
    private String textOf(List<String> row) {
        return this.textIndex < row.size() ? row.get(this.textIndex) : "";
    }

    /**
     * Pairs predictions, in order, with the non-blank rows of the batch.
     *
     * <p>Blank rows, and rows left without a prediction because the response was short, get
     * ERROR so that every output row has the same number of columns.
     */
    private void handleSuccessfulResult(List<Prediction> list, List<List<String>> list2) {
        List<Prediction> predictions = list == null ? List.of() : list;
        int next = 0;
        for (List<String> row : list2) {
            if (textOf(row).isBlank() || next >= predictions.size()) {
                appendError(row);
                continue;
            }
            Prediction prediction = predictions.get(next++);
            float similarity = prediction.getSimilarity() * prediction.getWeight();
            row.add(similarity < this.threshold ? OTHER_LABEL : prediction.getLabel());
            row.add(String.valueOf(similarity));
        }
    }

    /** Labels every row of a failed batch as ERROR with an empty score. */
    private void handleErrorResult(List<List<String>> list) {
        list.forEach(ClassifyCommand::appendError);
    }

    private static void appendError(List<String> row) {
        row.add(ERROR_LABEL);
        row.add("");
    }

    /**
     * Sends one classification request.
     *
     * @return the response when the status is below 400; {@code null} for a tolerated error
     * @throws IOException in strict mode or on quota exhaustion, carrying the response body
     */
    private HttpResponse<String> makeClassifyRequest(List<Prediction> list, HttpUtil httpUtil) throws Exception {
        String payload = this.jsonMapper.writeValueAsString(new ClassifyRequestBody(this.name, this.algorithm, list));
        HttpResponse<String> response = httpUtil.classify(payload);
        if (response.statusCode() < 400) {
            return response;
        }
        System.err.printf("Encountered error: %d %s\n", Integer.valueOf(response.statusCode()), response.body());
        if (this.isStrict || response.statusCode() == HttpUtil.QUOTA_EXCEEDED_STATUS_CODE) {
            throw new IOException(response.body());
        }
        return null;
    }
}
