import io
from collections import Counter
from pathlib import Path
from unittest.mock import MagicMock, patch, sentinel

import pyasn
import pytest

from asncounter import (
    BaseLineCollector,
    CollectionsRecorder,
    LineCollector,
    PrometheusRecorder,
    TcpdumpLineCollector,
    TupleCollector,
    args,
    main,
)

# All test data uses resources reserved for documentation, so it cannot
# collide with real networks: AS64496 and AS64497 from RFC 5398,
# 192.0.2.0/24 and 203.0.113.0/24 from RFC 5737, and 2001:db8::/32 from
# RFC 3849.

# Minimal pyasn database, one "prefix<TAB>ASN" entry per line. It is loaded
# with pyasn.pyasn(None, ipasn_string=SAMPLE_DB) and also written out as the
# cached .dat file in test_main. 203.0.113.0/24 and 2001:db8:2::/48 are not
# in it, so addresses there look up as None.
SAMPLE_DB = """\
192.0.2.0/24\t64496
2001:DB8::/48\t64496
2001:DB8:1::/48\t64497
"""

# Input for the line collectors: one address per line. It also contains a
# trailing comment, a comment-only line, a blank line and an unparseable line.
# Per the expected counters below, those extra lines are not counted at all.
TEST_LINE_ADDRESSES = """\
192.0.2.1 # comment
192.0.2.2
203.0.113.42
2001:DB8::1
2001:DB8:1::1
2001:DB8:2::1
# comment

garbage
"""

# Expected counters after collecting TEST_LINE_ADDRESSES. The two addresses
# outside SAMPLE_DB are counted under None. Prefixes are reported in
# lowercase even though the input and database use uppercase hex.
TEST_LINE_ASN = Counter({None: 2, 64496: 3, 64497: 1})
TEST_LINE_PREFIXES = Counter(
    {None: 2, "192.0.2.0/24": 2, "2001:db8:1::/48": 1, "2001:db8::/48": 1}
)

# Input for TupleCollector: "address weight" pairs, with the same noise
# lines as above. Weights for the same ASN/prefix are summed (see
# test_tuple_collector), and a weight of 0 is accepted.
TEST_TUPLE_ADDRESSES = """\
192.0.2.1 1 # comment
192.0.2.2 3.4
192.0.2.2 2.5
203.0.113.42 3
2001:DB8::1 0
2001:DB8:1::1 1337.3
2001:DB8:2::1 1
# comment

garbage
"""


@pytest.fixture
def recorder_setup(monkeypatch: pytest.MonkeyPatch) -> CollectionsRecorder:
    """Return a CollectionsRecorder whose lookups use the in-memory SAMPLE_DB.

    The options are set on asncounter's module-level ``args`` through
    ``monkeypatch``, so they are restored when the test finishes instead of
    leaking into later tests.
    """
    options = (
        ("no_asn", False),
        ("no_prefixes", False),
        ("no_resolve_asn", True),
        ("top", 10),
    )
    for name, value in options:
        # raising=False: the attribute may not exist before argument parsing
        monkeypatch.setattr(args, name, value, raising=False)

    recorder = CollectionsRecorder()
    # displaying results of an empty recorder must not divide by zero
    recorder.display_results()
    # lookups need an ASN database; this setup should probably be moved
    # into the recorder itself
    recorder.asndb = pyasn.pyasn(None, ipasn_string=SAMPLE_DB)
    return recorder


def test_line_collectors(recorder_setup: CollectionsRecorder) -> None:
    """Plain and tcpdump line collectors accumulate into one recorder.

    The plain collector reads TEST_LINE_ADDRESSES. The tcpdump collector
    then adds one packet from 192.0.2.3, which bumps AS64496 and
    192.0.2.0/24 by one each.
    """
    rec = recorder_setup

    # first pass: bare addresses, one per line
    plain: BaseLineCollector = LineCollector(recorder=rec)
    plain.collect(io.StringIO(TEST_LINE_ADDRESSES))
    assert rec.asn_counter == TEST_LINE_ASN
    assert rec.prefix_counter == TEST_LINE_PREFIXES

    # second pass: one tcpdump line; only its source address is counted
    tcpdump_input = io.StringIO(
        "10:41:33.342516 IP 192.0.2.3.65278 > 192.0.2.255.59387: UDP, length 42\n",
    )
    TcpdumpLineCollector(recorder=rec).collect(tcpdump_input)
    expected_asns = Counter({None: 2, 64496: 4, 64497: 1})
    expected_prefixes = Counter(
        {None: 2, "192.0.2.0/24": 3, "2001:db8:1::/48": 1, "2001:db8::/48": 1}
    )
    assert rec.asn_counter == expected_asns
    assert rec.prefix_counter == expected_prefixes

    # smoke-test rendering on a populated recorder
    rec.display_results()

    first_asn_prefixes = {"192.0.2.0/24", "2001:db8::/48"}
    assert rec.asn_prefixes(64496) == first_asn_prefixes
    assert rec.asn_prefixes(64496, 64497) == first_asn_prefixes | {
        "2001:db8:1::/48"
    }


def _assert_counts_close(actual: Counter, expected: dict) -> None:
    """Assert a weight counter matches ``expected`` within float tolerance.

    As with ``Counter`` equality (Python 3.10+), a key missing on one side
    is treated as a zero count. Unlike it, values are compared with
    ``pytest.approx``, because the weights are accumulated float sums.
    """
    keys = set(actual) | set(expected)
    assert {key: actual.get(key, 0) for key in keys} == pytest.approx(
        {key: expected.get(key, 0) for key in keys}
    )


def test_tuple_collector(recorder_setup: CollectionsRecorder) -> None:
    """TupleCollector adds each line's weight to its ASN and prefix.

    Lines are ``address weight``. Weights for the same ASN/prefix are
    summed, addresses outside SAMPLE_DB are counted under ``None``, and a
    zero weight is accepted. Exact float equality on such sums depends on
    summation order and rounding, so the totals are compared approximately.
    """
    recorder = recorder_setup

    TupleCollector(recorder=recorder).collect(io.StringIO(TEST_TUPLE_ADDRESSES))
    _assert_counts_close(
        recorder.asn_counter,
        {None: 4, 64496: 6.9, 64497: 1337.3},
    )
    _assert_counts_close(
        recorder.prefix_counter,
        {None: 4, "192.0.2.0/24": 6.9, "2001:db8:1::/48": 1337.3, "2001:db8::/48": 0},
    )


@pytest.mark.parametrize(
    ("line", "expected"),
    [
        ("garbage", (None, 0.0)),
        (
            "19:05:02.229065 IP 203.0.113.42.443 > 216.90.108.31.62122: tcp 60",
            ("203.0.113.42", 1.0),
        ),
        (
            "19:05:02.229065 IP6 2001:DB8::1.443 > 2001:DB8::2.62122: tcp 60",
            ("2001:DB8::1", 1.0),
        ),
        (
            "14:44:58.275872 IP [total length 52 > length 44] (invalid) 111.88.85.34.46277 > 116.202.120.165.443: tcp 0",  # noqa: E501
            ("111.88.85.34", 1.0),
        ),
    ],
)
def test_parse_tcpdump(line: str, expected: tuple[str | None, float]) -> None:
    """``parse`` returns the packet's source address and a weight of 1.0.

    Unparseable lines yield ``(None, 0.0)``. ``parse`` is not expected to
    use the recorder, so a bare ``MagicMock`` (typed as Any, so no
    ``type: ignore`` is needed) stands in for it.
    """
    collector = TcpdumpLineCollector(recorder=MagicMock())
    assert collector.parse(line) == expected


@patch("asncounter.prometheus_client")
def test_prometheus_record(
    mock_prom: MagicMock, monkeypatch: pytest.MonkeyPatch
) -> None:
    """``record`` labels the ASN and prefix counters with the AS name.

    ``prometheus_client`` is mocked, so every ``Counter(...)`` the recorder
    creates is the same mock. Both label sets are therefore checked on a
    single ``labels`` attribute.
    """
    # options for the recorder; monkeypatch restores them after the test
    monkeypatch.setattr(args, "no_asn", False, raising=False)
    monkeypatch.setattr(args, "no_prefixes", False, raising=False)
    monkeypatch.setattr(args, "no_resolve_asn", False, raising=False)

    recorder = PrometheusRecorder(port=None)
    labels = mock_prom.Counter.return_value.labels

    # a known address: stub the AS name lookup with a fixed answer
    recorder.lookup_asn = MagicMock(return_value="as name")  # type: ignore[method-assign]
    recorder.record(asn=64496, prefix="192.0.2.0/24")
    labels.assert_any_call(64496, "as name")
    labels.assert_any_call("192.0.2.0/24", 64496, "as name")

    # an unknown address: no ASN, no prefix, and the lookup finds no name,
    # which is expected to be labelled as an empty string
    recorder.lookup_asn = MagicMock(return_value=None)  # type: ignore[method-assign]
    recorder.record(asn=None, prefix=None)
    labels.assert_any_call(None, "")
    labels.assert_any_call(None, None, "")


@patch("asncounter.refresh_datfile")
@patch("asncounter.download_asnames")
def test_main(
    mock_download_asnames: MagicMock,
    mock_refresh_datfile: MagicMock,
    tmp_path: Path,
) -> None:
    """test main path

    This tests the cache lookup logic but not, obviously, the download
    logic, as that requires internet and we don't want to unit test
    upstream pyasn routines here, we trust they do their thing.

    Stacked ``@patch`` decorators are applied bottom-up, so the mock for
    the bottom one (``download_asnames``) is the first argument.
    """
    testfile = tmp_path / "testfile.txt"
    testfile.write_text(TEST_LINE_ADDRESSES)

    # pre-populate the cache directory so nothing needs downloading
    datfile = tmp_path / "ipasn_20250613.200000.dat"
    datfile.write_text(SAMPLE_DB)
    (tmp_path / "asnames.json").write_text("{}")
    # only used if main() unexpectedly refreshes, in which case the
    # "not called" assertions below fail
    mock_refresh_datfile.return_value = str(datfile)

    # look up a single IP given on the command line
    recorder = main(["--cache-directory", str(tmp_path), "192.0.2.1"])
    assert isinstance(recorder, CollectionsRecorder)
    assert not mock_download_asnames.called
    assert not mock_refresh_datfile.called
    # only the command-line IP is counted
    assert recorder.asn_counter == Counter({64496: 1})
    assert recorder.prefix_counter == Counter({"192.0.2.0/24": 1})

    # addresses can also be read from a file with --input
    recorder = main(["--cache-directory", str(tmp_path), "--input", str(testfile)])
    assert isinstance(recorder, CollectionsRecorder)
    assert recorder.asn_counter == TEST_LINE_ASN
    assert recorder.prefix_counter == TEST_LINE_PREFIXES
