import MetagraphOptimization.insert_node!
import MetagraphOptimization.insert_edge!
import MetagraphOptimization.make_node

# Unit tests for node reduction on a small hand-built DAG.
#
# Edges point from producer to consumer (`insert_edge!(graph, from, to, false)`):
#
#   AD -> A1C   -> A1D   \
#                          EC -> ED \
#   BD -> B1C_1 -> B1D_1 /           \
#                                     s0 -> d_exit
#   BD -> B1C_2 -> B1D_2 \           /
#                          FC -> FD /
#   CD -> C1C   -> C1D   /
#
# The entry node BD feeds two ComputeTaskU nodes (B1C_1, B1C_2). The test checks that
# these form the only reduction candidate initially, that reducing them exposes their
# data outputs (B1D_1, B1D_2) as the next candidate, and that pop_operation! and
# reset_graph! restore both the operation counts and the specific reduction candidates.
@testset "Unit Tests Node Reduction" begin
    graph = MetagraphOptimization.DAG()

    d_exit = insert_node!(graph, make_node(DataTask(10)), false)

    s0 = insert_node!(graph, make_node(ComputeTaskS2()), false)

    ED = insert_node!(graph, make_node(DataTask(3)), false)
    FD = insert_node!(graph, make_node(DataTask(3)), false)

    EC = insert_node!(graph, make_node(ComputeTaskV()), false)
    FC = insert_node!(graph, make_node(ComputeTaskV()), false)

    A1D = insert_node!(graph, make_node(DataTask(4)), false)
    B1D_1 = insert_node!(graph, make_node(DataTask(4)), false)
    B1D_2 = insert_node!(graph, make_node(DataTask(4)), false)
    C1D = insert_node!(graph, make_node(DataTask(4)), false)

    A1C = insert_node!(graph, make_node(ComputeTaskU()), false)
    B1C_1 = insert_node!(graph, make_node(ComputeTaskU()), false)
    B1C_2 = insert_node!(graph, make_node(ComputeTaskU()), false)
    C1C = insert_node!(graph, make_node(ComputeTaskU()), false)

    AD = insert_node!(graph, make_node(DataTask(5)), false)
    BD = insert_node!(graph, make_node(DataTask(5)), false)
    CD = insert_node!(graph, make_node(DataTask(5)), false)

    insert_edge!(graph, s0, d_exit, false)
    insert_edge!(graph, ED, s0, false)
    insert_edge!(graph, FD, s0, false)
    insert_edge!(graph, EC, ED, false)
    insert_edge!(graph, FC, FD, false)

    insert_edge!(graph, A1D, EC, false)
    insert_edge!(graph, B1D_1, EC, false)

    insert_edge!(graph, B1D_2, FC, false)
    insert_edge!(graph, C1D, FC, false)

    insert_edge!(graph, A1C, A1D, false)
    insert_edge!(graph, B1C_1, B1D_1, false)
    insert_edge!(graph, B1C_2, B1D_2, false)
    insert_edge!(graph, C1C, C1D, false)

    insert_edge!(graph, AD, A1C, false)
    # BD deliberately feeds two separate U computations; this is the reducible pair.
    insert_edge!(graph, BD, B1C_1, false)
    insert_edge!(graph, BD, B1C_2, false)
    insert_edge!(graph, CD, C1C, false)

    @test is_exit_node(d_exit)
    @test is_entry_node(AD)
    @test is_entry_node(BD)
    @test is_entry_node(CD)

    opt = get_operations(graph)
    @test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)

    # Initially the only reduction merges the two U computations that both read BD.
    nr = first(opt.nodeReductions)
    @test Set(nr.input) == Set([B1C_1, B1C_2])
    push_operation!(graph, nr)
    opt = get_operations(graph)
    @test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)

    # After that merge, the output data nodes B1D_1/B1D_2 are the next candidate.
    nr = first(opt.nodeReductions)
    @test Set(nr.input) == Set([B1D_1, B1D_2])
    push_operation!(graph, nr)
    opt = get_operations(graph)
    @test length(opt) == (nodeFusions = 4, nodeReductions = 0, nodeSplits = 1)

    # Undoing the second reduction must bring back exactly that candidate,
    # not merely an operation set with the same counts.
    pop_operation!(graph)
    opt = get_operations(graph)
    @test length(opt) == (nodeFusions = 4, nodeReductions = 1, nodeSplits = 1)
    @test Set(first(opt.nodeReductions).input) == Set([B1D_1, B1D_2])

    # Resetting must restore the initial state, including the original candidate.
    reset_graph!(graph)
    opt = get_operations(graph)
    @test length(opt) == (nodeFusions = 6, nodeReductions = 1, nodeSplits = 1)
    @test Set(first(opt.nodeReductions).input) == Set([B1C_1, B1C_2])
end
println("Node Reduction Unit Tests Complete!")