import io
import os
import pathlib
import tempfile
import unittest

from google.protobuf.message import DecodeError

import gtirb

# Path of the scratch GTIRB file used by IRTest: setUp writes the fixture IR
# to it and tearDown removes it again.
# NOTE(review): tempfile.mktemp only returns a file name without creating the
# file and is documented as deprecated because another process may create
# that name first; consider tempfile.mkstemp or a TemporaryDirectory.
IR_FILE = tempfile.mktemp(suffix=".gtirb")


class IRTest(unittest.TestCase):
    """
    Round-trip a small hand-built IR through the protobuf file at IR_FILE
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Every test-case instance carries its own freshly built fixture IR.
        self.ir = self._build_ir()

    @staticmethod
    def _build_ir():
        """
        Build the fixture IR: one module holding a section, a byte interval
        with a code block and a data block, a symbol, a symbolic expression,
        a proxy block, two CFG edges and one aux data entry on both the
        module and the IR.
        """
        root = gtirb.IR()
        module = gtirb.Module(
            binary_path="binary_path",
            file_format=gtirb.Module.FileFormat.RAW,
            isa=gtirb.Module.ISA.ValidButUnsupported,
            name="name",
            preferred_addr=1,
            rebase_delta=2,
            ir=root,
        )

        section_flags = (
            gtirb.Section.Flag.Executable,
            gtirb.Section.Flag.Readable,
            gtirb.Section.Flag.Loaded,
            gtirb.Section.Flag.Initialized,
        )
        section = gtirb.Section(
            name="name", flags=section_flags, module=module
        )
        interval = gtirb.ByteInterval(
            address=0, size=10, contents=b"abcd", section=section
        )
        code_block = gtirb.CodeBlock(
            size=4,
            offset=0,
            decode_mode=gtirb.CodeBlock.DecodeMode.Thumb,
            byte_interval=interval,
        )
        # Only constructed for its byte_interval= argument; the returned
        # object is not referenced again in this fixture.
        gtirb.DataBlock(size=6, offset=4, byte_interval=interval)

        symbol = gtirb.Symbol(name="name", payload=code_block, module=module)
        interval.symbolic_expressions[2] = gtirb.SymAddrConst(
            0, symbol, {gtirb.SymbolicExpression.Attribute.G1}
        )

        proxy = gtirb.ProxyBlock(module=module)
        branch_label = gtirb.Edge.Label(
            type=gtirb.Edge.Type.Branch, conditional=False, direct=True
        )
        root.cfg.add(gtirb.Edge(code_block, proxy, branch_label))
        root.cfg.add(gtirb.Edge(proxy, proxy))

        module.aux_data["key"] = gtirb.AuxData(
            gtirb.Offset(section, 777), "Offset"
        )
        root.aux_data["key"] = gtirb.AuxData("value", "string")
        return root

    def setUp(self):
        # Each test starts with the fixture IR already saved at IR_FILE.
        self.ir.save_protobuf(IR_FILE)

    def tearDown(self):
        os.remove(IR_FILE)

    def _check_reloaded(self, reloaded):
        """
        Assert that `reloaded` is deep-equal to the fixture IR while the
        Offset aux data of their first modules compares unequal.
        """
        self.assertTrue(self.ir.deep_eq(reloaded))
        # NOTE(review): presumably unequal because each Offset refers to the
        # section object of its own IR -- confirm against gtirb.Offset.
        fixture_offset = self.ir.modules[0].aux_data["key"].data
        reloaded_offset = reloaded.modules[0].aux_data["key"].data
        self.assertNotEqual(fixture_offset, reloaded_offset)

    def test_ir_protobuf_load(self):
        """
        Ensure an IR loaded from a `str` path matches the IR that was saved
        """
        self._check_reloaded(gtirb.IR.load_protobuf(IR_FILE))

    def test_load_pathlib(self):
        """
        Ensure `load_protobuf` and `save_protobuf` accept `pathlib.Path`
        """
        as_path = pathlib.Path(IR_FILE)
        reloaded = gtirb.IR.load_protobuf(as_path)
        self._check_reloaded(reloaded)
        reloaded.save_protobuf(as_path)


class NotGTIRBTest(unittest.TestCase):
    """
    Loading a stream that does not start with the GTIRB magic must fail
    """

    def test(self):
        junk_stream = io.BytesIO(b"JUNK")
        with self.assertRaises(Exception) as caught:
            gtirb.IR.load_protobuf_file(junk_stream)

        # The exact message is part of the checked contract.
        expected_message = "File missing GTIRB magic - not a GTIRB file?"
        self.assertEqual(expected_message, str(caught.exception))


class BadVersionTest(unittest.TestCase):
    """
    Loading a stream with the GTIRB magic but a bad version byte must fail
    """

    def test(self):
        # b"GTIRB", two zero bytes, then 0xFF. The trailing byte is
        # presumably the version field, since the expected error message
        # complains about the IR version.
        file_content = io.BytesIO(b"GTIRB\x00\x00\xFF")
        with self.assertRaises(Exception) as context:
            gtirb.IR.load_protobuf_file(file_content)

        # assertIn reports both the needle and the actual message when it
        # fails, unlike assertTrue(... in ...), which only says "False is
        # not true".
        self.assertIn(
            "Attempt to decode IR of version", str(context.exception)
        )


class BadProtobufTest(unittest.TestCase):
    """
    A valid header followed by a junk payload must raise a protobuf
    DecodeError
    """

    def test(self):
        # Header: b"GTIRB", two zero bytes, then the protobuf version that
        # this gtirb build reports, encoded as a single byte.
        version_byte = gtirb.version.PROTOBUF_VERSION.to_bytes(
            1, byteorder="little"
        )
        # Named `payload` rather than `bytes` so the builtin is not shadowed.
        payload = b"GTIRB\x00\x00" + version_byte + b"JUNK"
        file_content = io.BytesIO(payload)
        # The raised exception itself is not inspected, so it is not bound.
        with self.assertRaises(DecodeError):
            gtirb.IR.load_protobuf_file(file_content)


class IRMethodTests(unittest.TestCase):
    """
    Tests for lookup methods on `gtirb.IR`
    """

    def test_modules_named(self):
        """
        Check that IR.modules_named yields every module with the given name
        """
        ir = gtirb.IR()

        def new_module(name: str):
            # Each module is created with ir=ir, the IR queried below.
            return gtirb.Module(
                file_format=gtirb.Module.FileFormat.RAW,
                isa=gtirb.Module.ISA.ValidButUnsupported,
                name=name,
                ir=ir,
            )

        only_m1 = new_module("m1")
        only_m2 = new_module("m2")
        twin_a = new_module("m3")
        twin_b = new_module("m3")

        # Names used once: the first module yielded is the one created.
        for name, expected in (("m1", only_m1), ("m2", only_m2)):
            self.assertEqual(next(ir.modules_named(name)), expected)

        # A name used twice: both modules are yielded, in any order.
        twins = list(ir.modules_named("m3"))
        self.assertEqual(len(twins), 2)
        for twin in (twin_a, twin_b):
            self.assertIn(twin, twins)


# Run the tests in this module when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
