/*
 * Decompiled with CFR 0.152.
 */
package ai.djl.modality.rl.agent;

import ai.djl.Device;
import ai.djl.modality.rl.ActionSpace;
import ai.djl.modality.rl.agent.RlAgent;
import ai.djl.modality.rl.env.RlEnv;
import ai.djl.ndarray.NDArray;
import ai.djl.ndarray.NDList;
import ai.djl.training.GradientCollector;
import ai.djl.training.Trainer;
import ai.djl.training.listener.TrainingListener;
import ai.djl.translate.Batchifier;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Objects;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Stream;

/**
 * A reinforcement-learning agent that scores (observation, action) pairs with the model held by a
 * {@link Trainer} and trains it toward a Q-learning target.
 *
 * <p>Each model input is the observation's arrays followed by one action's arrays. The model
 * output's trailing dimension is squeezed away, leaving one score per input. NOTE(review): this
 * assumes the model emits an output of shape (batch, 1) — confirm against the networks used
 * with this agent.
 */
public class QAgent implements RlAgent {

    private final Trainer trainer;
    private final float rewardDiscount;
    private final Batchifier batchifier;

    /**
     * Constructs a {@link QAgent} that batches model inputs with {@link Batchifier#STACK}.
     *
     * @param trainer the trainer whose model is evaluated and trained
     * @param rewardDiscount the factor multiplied into the best next-state score when building
     *     the training target
     * @throws NullPointerException if {@code trainer} is null
     */
    public QAgent(Trainer trainer, float rewardDiscount) {
        this(trainer, rewardDiscount, Batchifier.STACK);
    }

    /**
     * Constructs a {@link QAgent}.
     *
     * @param trainer the trainer whose model is evaluated and trained
     * @param rewardDiscount the factor multiplied into the best next-state score when building
     *     the training target
     * @param batchifier combines the per-action inputs into a single model batch
     * @throws NullPointerException if {@code trainer} or {@code batchifier} is null
     */
    public QAgent(Trainer trainer, float rewardDiscount, Batchifier batchifier) {
        // Fail fast here instead of with an NPE deep inside chooseAction/trainBatch.
        this.trainer = Objects.requireNonNull(trainer, "trainer");
        this.rewardDiscount = rewardDiscount;
        this.batchifier = Objects.requireNonNull(batchifier, "batchifier");
    }

    /**
     * Returns the action in the environment's action space whose (observation, action) pair gets
     * the highest model score.
     *
     * <p>The {@code training} flag is ignored: selection is always greedy.
     *
     * @param env the environment supplying the current observation and action space
     * @param training unused by this agent
     * @return the highest-scoring action from {@code env.getActionSpace()}
     */
    @Override
    public NDList chooseAction(RlEnv env, boolean training) {
        ActionSpace actionSpace = env.getActionSpace();
        NDList[] inputs = buildInputs(env.getObservation(), actionSpace);
        NDArray actionScores =
                trainer.evaluate(batchifier.batchify(inputs)).singletonOrThrow().squeeze(-1);
        int bestAction = Math.toIntExact(actionScores.argMax().getLong());
        return actionSpace.get(bestAction);
    }

    /**
     * Runs a forward and backward pass for every step, then notifies the trainer's listeners once
     * for the whole batch.
     *
     * <p>This method does not call {@code trainer.step()}. NOTE(review): presumably the caller
     * applies the optimizer update after this returns — confirm.
     *
     * @param batchSteps the recorded environment steps to learn from
     */
    @Override
    public void trainBatch(RlEnv.Step[] batchSteps) {
        TrainingListener.BatchData batchData =
                new TrainingListener.BatchData(
                        null, new ConcurrentHashMap<>(), new ConcurrentHashMap<>());
        for (RlEnv.Step step : batchSteps) {
            trainStep(step, batchData);
        }
        trainer.notifyListeners(listener -> listener.onTrainingBatch(trainer, batchData));
    }

    /**
     * Scores the taken action together with every follow-up action in a single forward pass,
     * back-propagates the loss against the Q-learning target, and records the label and prediction
     * into {@code batchData}.
     */
    private void trainStep(RlEnv.Step step, TrainingListener.BatchData batchData) {
        // Row 0 is the (pre-observation, taken action) pair. The remaining rows are
        // (post-observation, candidate action) pairs for every action available afterwards.
        NDList[] preInputs =
                buildInputs(step.getPreObservation(), Collections.singletonList(step.getAction()));
        NDList[] postInputs = buildInputs(step.getPostObservation(), step.getPostActionSpace());
        NDList[] allInputs =
                Stream.concat(Arrays.stream(preInputs), Arrays.stream(postInputs))
                        .toArray(NDList[]::new);

        try (GradientCollector collector = trainer.newGradientCollector()) {
            NDArray results =
                    trainer.forward(batchifier.batchify(allInputs)).singletonOrThrow().squeeze(-1);
            NDList preQ = new NDList(results.get(0));
            NDList postQ = new NDList(targetQ(step, results));
            NDArray lossValue = trainer.getLoss().evaluate(postQ, preQ);
            collector.backward(lossValue);
            // NOTE(review): put() replaces any earlier entry for the same device, so listeners
            // only ever see the last step's label/prediction for each device.
            batchData.getLabels().put(postQ.get(0).getDevice(), postQ);
            batchData.getPredictions().put(preQ.get(0).getDevice(), preQ);
        }
    }

    /**
     * Computes the training target for {@code step}.
     *
     * <p>For a terminal step the target is the reward. Otherwise it is the reward plus {@code
     * rewardDiscount} times the best score among the post-step rows of {@code results} (rows
     * {@code 1..n}).
     */
    private NDArray targetQ(RlEnv.Step step, NDArray results) {
        if (step.isDone()) {
            return step.getReward();
        }
        // NOTE(review): the target comes from the same forward pass, inside the gradient
        // collector, so gradient may also flow through it. Standard Q-learning usually treats the
        // target as a constant — confirm whether this is intended.
        NDArray bestNextQ = results.get("1:").max();
        return bestNextQ.mul(rewardDiscount).add(step.getReward());
    }

    /**
     * Pairs {@code observation} with each action, producing one model input per action.
     *
     * @param observation the arrays describing the environment state, prepended to every input
     * @param actions the candidate actions
     * @return an array where element {@code i} holds the observation arrays followed by the arrays
     *     of {@code actions.get(i)}
     */
    private NDList[] buildInputs(NDList observation, List<NDList> actions) {
        NDList[] inputs = new NDList[actions.size()];
        for (int i = 0; i < inputs.length; i++) {
            inputs[i] = new NDList().addAll(observation).addAll(actions.get(i));
        }
        return inputs;
    }
}

