using MetagraphOptimization
using Random

"""
    test_known_graph(name::String, n, fusion_test = true)

Parse the graph file `<this file's directory>/../input/<name>.txt` with `ABCModel()`
and, inside a test set named after the graph, optionally run the node-fusion test
(when `fusion_test` is `true`) followed by a random walk of `n` operations.
"""
function test_known_graph(name::String, n, fusion_test = true)
    @testset "Test $name Graph ($n)" begin
        graph = parse_dag(joinpath(@__DIR__, "..", "input", "$name.txt"), ABCModel())

        if fusion_test
            test_node_fusion(graph)
        end
        test_random_walk(graph, n)
    end
end

"""
    test_node_fusion(g::DAG)

Repeatedly push the first node fusion reported by `get_operations(g)` onto `g`
until no node fusions remain. After every fusion, check that:

- the pushed operation is a `NodeFusion`,
- the graph's `data` property strictly decreased, and
- the graph's `computeEffort` property is unchanged.

`g` is mutated: all pushed fusions stay on it when this returns.
"""
function test_node_fusion(g::DAG)
    @testset "Test Node Fusion" begin
        props = get_properties(g)

        options = get_operations(g)

        # Properties of the previous state, compared against after each fusion.
        data = props.data
        compute_effort = props.computeEffort

        while !isempty(options.nodeFusions)
            fusion = first(options.nodeFusions)

            @test fusion isa NodeFusion

            push_operation!(g, fusion)

            props = get_properties(g)
            @test props.data < data
            @test props.computeEffort == compute_effort

            data = props.data
            compute_effort = props.computeEffort

            # The set of available operations changes after every push.
            options = get_operations(g)
        end
    end
end

"""
    test_random_walk(g::DAG, n::Integer)

Reset `g`, then perform `n` random steps on it. Each step either pushes a
randomly chosen node fusion, node reduction or node split, or pops the most
recently pushed operation. Finally reset `g` again and check that it is valid
and that its properties equal those recorded before the walk.

The purpose is to apply "random" operations, reverse them again, and confirm
that the graph returns to the same state instead of diverging.

Attempts that cannot be carried out are retried and do not count towards `n`.
This happens when the chosen kind of operation is unavailable or there is
nothing to pop. If the graph offers no operation at all and nothing can be
popped, the walk stops early instead of looping forever.
"""
function test_random_walk(g::DAG, n::Integer)
    @testset "Test Random Walk ($n)" begin
        reset_graph!(g)

        @test is_valid(g)

        properties = get_properties(g)

        # Reassigning the variable of a `for` loop does not repeat an iteration
        # in Julia, so failed attempts are retried explicitly with a counter.
        performed = 0
        while performed < n
            if _random_walk_step!(g)
                performed += 1
            elseif !can_pop(g) && !_has_any_operation(get_operations(g))
                # Nothing to push and nothing to pop: the walk cannot progress.
                break
            end
        end

        reset_graph!(g)

        @test is_valid(g)

        @test properties == get_properties(g)
    end
end

"""
    _random_walk_step!(g::DAG)

Attempt one random step of the walk on `g`: with equal probability either push a
random operation of a randomly chosen kind (fusion, reduction, split) or pop the
last operation. Return `true` if an operation was pushed or popped, `false` if
the chosen action was not possible.
"""
function _random_walk_step!(g::DAG)
    if rand(Bool)
        # push: choose one of fuse/reduce/split
        opt = get_operations(g)
        option = rand(1:3)
        if option == 1 && !isempty(opt.nodeFusions)
            push_operation!(g, rand(collect(opt.nodeFusions)))
        elseif option == 2 && !isempty(opt.nodeReductions)
            push_operation!(g, rand(collect(opt.nodeReductions)))
        elseif option == 3 && !isempty(opt.nodeSplits)
            push_operation!(g, rand(collect(opt.nodeSplits)))
        else
            return false
        end
        return true
    end

    # pop
    can_pop(g) || return false
    pop_operation!(g)
    return true
end

# Whether the operation collection returned by `get_operations` offers anything to push.
function _has_any_operation(opt)
    return !(isempty(opt.nodeFusions) && isempty(opt.nodeReductions) &&
             isempty(opt.nodeSplits))
end

# Fixed seed so the random walks are reproducible between runs.
Random.seed!(0)

# Each case: (graph input name, random-walk length, whether to run the fusion test)
for (graph_name, walk_length, with_fusion) in (
    ("AB->AB", 10000, true),
    ("AB->ABBB", 10000, true),
    ("AB->ABBBBB", 1000, false),
)
    test_known_graph(graph_name, walk_length, with_fusion)
end