import React, {useEffect, useState} from "react";
import {useHistory} from "react-router-dom";
import {useSnackbar} from "notistack";
import Grid from "@mui/material/Grid";
import Tooltip from "@mui/material/Tooltip";
import IconButton from "@mui/material/IconButton";
import NextIcon from "@mui/icons-material/NavigateNext";
import ViewIcon from "@mui/icons-material/Visibility";
import HideIcon from '@mui/icons-material/VisibilityOff';
import DownloadIcon from "@mui/icons-material/CloudDownload";
import Typography from "@mui/material/Typography";
import ImageSamples from "./ImageSamples";
import {useReports} from "./Reports";
import {checkError, errOptions, getLogger, infoOptions} from "../util/util";

// Module-level logger, obtained from the shared `getLogger` util under the name "report.viewer".
const log = getLogger("report.viewer");

/**
 * Order sample refs by how much the attack changed their metric, largest change first.
 *
 * The first key of `attacked_metrics` selects the metric to sort on. Every returned scored ref is a
 * shallow copy annotated with `metric_name`, `attacked_metric` and `benign_metric`; the input ref
 * objects are not modified. Refs that have no metric entry are kept, unannotated, after the scored ones.
 *
 * @param {Array<Object>} refs - sample references; presumably index-aligned with the metric arrays
 *   (TODO confirm against the report format).
 * @param {Object} benign_metrics - metric name -> per-sample values without attack.
 * @param {Object} attacked_metrics - metric name -> per-sample values under attack.
 * @param {number} [threshold=0.3] - absolute metric change above which a sample counts as adversarial.
 * @returns {Array} `[refs, advcount]`: the reordered refs and the number of metric entries whose change
 *   exceeds `threshold`; the original `refs` and `null` when no metric is present or sorting fails.
 */
function sortRefsByMetric(refs, benign_metrics, attacked_metrics, threshold = 0.3) {
    let advcount = null;
    try {
        const metric_name = Object.keys(attacked_metrics)[0];
        if (metric_name == null) {
            log.debug("No metric found to sort on.");
            return [refs, advcount];
        }
        if (!["mIoU", "AP", "AP50"].includes(metric_name)) {
            log.debug("Sorting by metric " + metric_name);
        }
        const attacked_metric = attacked_metrics[metric_name];
        const benign_metric = benign_metrics[metric_name];
        // Note: when deciding on adversarial samples if a defense is applied,
        // ideally the benign metric should be from
        // the performance when only defense is applied, but we don't have that here.
        // Score each metric entry once instead of recomputing the difference in the comparator.
        const scored = Array.from(attacked_metric.keys(), i => ({
            i,
            delta: Math.abs(benign_metric[i] - attacked_metric[i]),
        }));
        // A NaN delta (missing value) is never counted as adversarial.
        advcount = scored.filter(s => s.delta > threshold).length;
        log.debug("Found " + advcount + " adversarial samples");

        // NaN deltas would make the comparator inconsistent; rank them below every real delta.
        const rank = s => (Number.isNaN(s.delta) ? -1 : s.delta);
        const sorted = scored
            .filter(s => s.i < refs.length) // ignore metric entries without a matching ref
            .sort((a, b) => rank(b) - rank(a))
            .map(({i}) => ({
                ...refs[i],
                metric_name,
                attacked_metric: attacked_metric[i],
                benign_metric: benign_metric[i],
            }));
        // Keep refs that have no metric entry instead of silently dropping them.
        const scoredIdx = new Set(scored.map(s => s.i));
        const unscored = refs.filter((_, i) => !scoredIdx.has(i));
        return [sorted.concat(unscored), advcount];
    } catch (e) {
        log.error("Failed to sort by metric", e);
        return [refs, advcount];
    }
}

export default function ReportSampleViewer({title, id}) {
    const history = useHistory();
    const {enqueueSnackbar} = useSnackbar();
    const {getReportMetaData, getSampleData, getReportFile, downloadReportData} = useReports();

    const [images, setImages] = useState([]);
    const [imageCnt, setImageCnt] = useState(0);
    const [fileRefs, setFileRefs] = useState(null);
    const [refIsPath, setRefIsPath] = useState(false);
    const [hideSamples, setHideSamples] = useState(false);
    const [adversarialCnt, setAdversarialCnt] = useState(null);

    useEffect(() => {
        getReportMetaData(id).then(data => {
            setImageCnt(0);
            setImages([]);
            if (data != null) {
                if (data.version === 2) {
                    if (data.file_refs != null) {
                        setRefIsPath(true);
                        let refs = Object.values(data.file_refs);
                        refs = refs.map(r => {return {path: r}});
                        if (data.attacked != null && data.attacked.length > 0) {
                            const [sortedRefs, advCnt] = sortRefsByMetric(refs, data.benign[0], data.attacked[0]);
                            setAdversarialCnt(advCnt);
                            refs = sortedRefs;
                        }
                        setFileRefs(refs);
                    }
                } else {
                    setFileRefs(data.file_refs);
                }
            }
        });
    }, [id, getReportMetaData])

    const handleView = async (tag, imageCnt) => {
        try {
            let idx = images.length;
            let max = 1000;
            let ref;
            if (fileRefs != null) { // we have metadata in this report
                if (images.length >= fileRefs.length) {
                    enqueueSnackbar('No more images available.', infoOptions);
                    return;
                }
                ref = fileRefs[images.length];
                idx = 0;
                max = fileRefs.length;
            }
            let data;
            if (refIsPath) {
                data = await getReportFile(id, ref.path);
            } else {
                data = await getSampleData(id, idx, ref);
            }
            images.push({img: URL.createObjectURL(data),
                title: "image_" + images.length, metric_name: ref.metric_name,
                attacked_metric: ref.attacked_metric, benign_metric: ref.benign_metric});
            setImages([...images]);
            if (images.length < imageCnt) {
                imageCnt = Math.min(max, imageCnt);
                setImageCnt(imageCnt);
                handleView(tag, imageCnt);
            }
        } catch (e) {
            if (e.response != null && e.response.status === 404) {
                if (tag === 'attacked') {
                    handleView('benign', imageCnt);
                } else {
                    enqueueSnackbar('No more images found.', errOptions);
                }
            } else {
                checkError(e, history, () => enqueueSnackbar('Failed to view more files', errOptions))
            }
        }
    };

    return <><Grid container direction={"column"}>
        <Grid item>
            {title != null &&
                <Typography variant="body1" component="span">{title}</Typography>
            }
            <Tooltip title={images.length > 0 ? "Load more samples" : "View samples"} aria-label="View samples">
                <IconButton aria-label="view" size="small" onClick={() => {
                    if (hideSamples) {
                        setHideSamples(false);
                    } else {
                        handleView('attacked', imageCnt + 20);
                    }
                }}>
                    {images.length > 0 ? <NextIcon/> : <ViewIcon/>}
                </IconButton>
            </Tooltip>
            {!hideSamples && images.length > 0 &&
                <Tooltip title={"Hide samples"} aria-label="Hide samples">
                    <IconButton aria-label="hide" size="small"  onClick={() => {
                        setHideSamples(true);
                    }}>
                        <HideIcon/>
                    </IconButton>
                </Tooltip>
            }
            <Tooltip title={"Download samples"} aria-label="Download samples">
                <IconButton aria-label="download" size="small"  onClick={() => downloadReportData(id)}>
                    <DownloadIcon/>
                </IconButton>
            </Tooltip>
            {adversarialCnt != null && adversarialCnt > 0 &&
                <Typography variant="body1" component="span"> (Approximately {adversarialCnt} adversarial samples)</Typography>
            }
        </Grid>
        <Grid item>
            {!hideSamples && images.length > 0 && <ImageSamples cols={10} items={images} imwidth={32} imheight={32}/>}
        </Grid>
    </Grid>
    </>
}
