Skip to content

Commit

Permalink
Merge pull request #58 from SpM-lab/terasaki/improve-refine_grid-impl…
Browse files Browse the repository at this point in the history
…ementation

Improve the `SparseIR.refine_grid` function
  • Loading branch information
SamuelBadr authored Dec 12, 2024
2 parents 8fbff17 + eee06e8 commit 3143c8d
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 2 deletions.
5 changes: 3 additions & 2 deletions src/_roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -39,10 +39,11 @@ end
closeenough(a::T, b::T, ϵ) where {T<:AbstractFloat} = isapprox(a, b; rtol=0, atol=ϵ)
closeenough(a::T, b::T, _) where {T<:Integer} = a == b

function refine_grid(grid, ::Val{α}) where {α}
function refine_grid(grid::Vector{T}, ::Val{α}) where {T, α}
isempty(grid) && return float(T)[]
n = length(grid)
newn = α * (n - 1) + 1
newgrid = Vector{eltype(grid)}(undef, newn)
newgrid = Vector{float(T)}(undef, newn)

@inbounds for i in 1:(n - 1)
xb = grid[i]
Expand Down
77 changes: 77 additions & 0 deletions test/_roots.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,83 @@ using Test
using SparseIR

@testset "_roots.jl" begin
@testset "refine_grid" begin
@testset "Basic refinement" begin
# Test with α = 2
grid = [0.0, 1.0, 2.0]
refined = @inferred SparseIR.refine_grid(grid, Val(2))
@test length(refined) == 5 # α * (n-1) + 1 = 2 * (3-1) + 1 = 5
@test refined [0.0, 0.5, 1.0, 1.5, 2.0]

# Test with α = 3
refined3 = SparseIR.refine_grid(grid, Val(3))
@test length(refined3) == 7 # α * (n-1) + 1 = 3 * (3-1) + 1 = 7
@test refined3 [0.0, 1/3, 2/3, 1.0, 4/3, 5/3, 2.0]

# Test with α = 4
refined4 = SparseIR.refine_grid(grid, Val(4))
@test length(refined4) == 9 # α * (n-1) + 1 = 4 * (3-1) + 1 = 9
@test refined4 [0.0, 0.25, 0.5, 0.75, 1.0, 1.25, 1.5, 1.75, 2.0]
end

#=
Currently the SparseIR.refine_grid function is used only for the
SparseIR.find_all function, which should accept only Float64 grids, but
just in case it is a good idea to test that it works for other types.
=#
@testset "Type stability" begin
# Integer grid
int_grid = [0, 1, 2]
refined = @inferred SparseIR.refine_grid(int_grid, Val(2))
@test eltype(refined) === Float64
@test refined == [0.0, 0.5, 1.0, 1.5, 2.0]

# Float32 grid
f32_grid = Float32[0, 1, 2]
refined_f32 = @inferred SparseIR.refine_grid(f32_grid, Val(2))
@test eltype(refined_f32) === Float32
@test refined_f32 Float32[0, 0.5, 1, 1.5, 2]
end

@testset "Edge cases" begin
# Single interval
single_interval = [0.0, 1.0]
refined_single = SparseIR.refine_grid(single_interval, Val(4))
@test length(refined_single) == 5 # α * (2-1) + 1 = 4 * 1 + 1 = 5
@test refined_single [0.0, 0.25, 0.5, 0.75, 1.0]

# Empty grid
empty_grid = Float64[]
@test isempty(SparseIR.refine_grid(empty_grid, Val(2)))
# Empty grid
empty_grid = Int[]
out_grid = SparseIR.refine_grid(empty_grid, Val(2))
@test isempty(out_grid)
@test eltype(out_grid) === Float64
# Single point
single_point = [1.0]
@test SparseIR.refine_grid(single_point, Val(2)) == [1.0]
end

@testset "Uneven spacing" begin
# Test with unevenly spaced grid
uneven = [0.0, 1.0, 10.0]
refined_uneven = SparseIR.refine_grid(uneven, Val(2))
@test length(refined_uneven) == 5
@test refined_uneven[1:3] [0.0, 0.5, 1.0] # First interval
@test refined_uneven[3:5] [1.0, 5.5, 10.0] # Second interval
end

@testset "Preservation of endpoints" begin
grid = [-1.0, 0.0, 1.0]
for α in [2, 3, 4]
refined = SparseIR.refine_grid(grid, Val(α))
@test first(refined) == first(grid)
@test last(refined) == last(grid)
end
end
end

@testset "discrete_extrema" begin
nonnegative = collect(0:8)
symmetric = collect(-8:8)
Expand Down

0 comments on commit 3143c8d

Please sign in to comment.