Skip to content

Commit

Permalink
Fix augmented basis
Browse files Browse the repository at this point in the history
  • Loading branch information
SamuelBadr committed Aug 17, 2022
1 parent e8a14d5 commit 0f25851
Show file tree
Hide file tree
Showing 24 changed files with 319 additions and 319 deletions.
8 changes: 4 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -17,11 +17,11 @@ pkg> add SparseIR

```julia
using SparseIR
beta = 10.0
β = 10.0
ωmax = 1.0
eps = 1e-7
basis_f = FiniteTempBasis(fermion, beta, ωmax, eps)
basis_b = FiniteTempBasis(boson, beta, ωmax, eps)
ε = 1e-7
basis_f = FiniteTempBasis(Fermionic(), β, ωmax, ε)
basis_b = FiniteTempBasis(Bosonic(), β, ωmax, ε)
```

## Tutorial and sample codes
Expand Down
2 changes: 1 addition & 1 deletion docs/src/guide.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ for us, where 80.0 is the value of the scale parameter ``\Lambda = \beta\omega_\

### SVE

Central is the _singular value expansion_'s (SVE) computation, which is handled by the function `compute_sve`:
Central is the _singular value expansion_'s (SVE) computation, which is handled by the function `SVEResult`:
Its purpose is constructing the decomposition
```math
\begin{equation}\label{SVE}
Expand Down
5 changes: 1 addition & 4 deletions src/SparseIR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ Intermediate representation (IR) for many-body propagators.
"""
module SparseIR

export fermion, boson
export Fermionic, Bosonic
export MatsubaraFreq, BosonicFreq, FermionicFreq, pioverbeta
export FiniteTempBasis
export SparsePoleRepresentation, to_IR, from_IR
Expand Down Expand Up @@ -32,9 +32,6 @@ include("_specfuncs.jl")
using ._LinAlg: tsvd

include("freq.jl")
const boson = Bosonic()
const fermion = Fermionic()

include("abstract.jl")
include("svd.jl")
include("gauss.jl")
Expand Down
4 changes: 2 additions & 2 deletions src/_linalg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ function rrqr!(A::AbstractMatrix{T}; rtol=eps(T)) where {T<:AbstractFloat}
taus = Vector{T}(undef, k)
swapcol = Vector{T}(undef, m)

xnorms = map(norm, eachcol(A))
xnorms = norm.(eachcol(A))
pnorms = copy(xnorms)
sqrteps = sqrt(eps(T))

Expand Down Expand Up @@ -296,7 +296,7 @@ function svd_jacobi!(U::AbstractMatrix{T}; rtol=eps(T), maxiter=20) where {T}
offd < rtol * Unorm && break
end

s = map(norm, eachcol(U))
s = norm.(eachcol(U))
U ./= transpose(s)
return SVD(U, s, VT)
end
Expand Down
9 changes: 4 additions & 5 deletions src/abstract.jl
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,10 @@ as follows:
where `basis.uhat[l]` is now the Fourier transform of the basis function.
"""
abstract type AbstractBasis end
abstract type AbstractBasis{S <: Statistics} end

Base.size(::AbstractBasis) = error("unimplemented")
Base.broadcastable(b::AbstractBasis) = Ref(b)

"""
Base.getindex(basis::AbstractBasis, I)
Expand All @@ -30,10 +31,8 @@ Return basis functions/singular values for given index/indices.
This can be used to truncate the basis to the `n` most significant
singular values: `basis[1:3]`.
"""
Base.getindex(::AbstractBasis, I) = error("unimplemented")

Base.getindex(::AbstractBasis, _) = error("unimplemented")
Base.firstindex(::AbstractBasis) = 1

Base.length(basis::AbstractBasis) = length(basis.s)

"""
Expand Down Expand Up @@ -76,7 +75,7 @@ default_matsubara_sampling_points(::AbstractBasis) = error("unimplemented")
Quantum statistic (Statistics instance, Fermionic() or Bosonic()).
"""
statistics(basis::AbstractBasis) = basis.statistics
statistics(::AbstractBasis{S}) where {S<:Statistics} = S()

"""
Λ(basis::AbstractBasis)
Expand Down
151 changes: 78 additions & 73 deletions src/augment.jl
Original file line number Diff line number Diff line change
@@ -1,3 +1,19 @@
"""
AbstractAugmentation
Scalar function in imaginary time/frequency.
This represents a single function in imaginary time and frequency, together
with some auxiliary methods that make it suitable for augmenting a basis.
See also: [`AugmentedBasis`](@ref)
"""
abstract type AbstractAugmentation <: Function end

const AugmentationTuple = Tuple{Vararg{<:AbstractAugmentation}}

create(aug::AbstractAugmentation, ::AbstractBasis) = aug

"""
AugmentedBasis <: AbstractBasis
Expand All @@ -21,41 +37,42 @@ that serves as a base for multi-point functions [^shinaoka2018].
Bases augmented with `TauConst` and `TauLinear` tend to be poorly
conditioned. Care must be taken while fitting and compactness should
be enforced if possible to regularize the problem.
While vertex bases, i.e. bases augmented with `MatsubaraConst`, stay
reasonably well-conditioned, it is still good practice to treat the
Hartree--Fock term separately rather than including it in the basis,
if possible.
See also: [`MatsubaraConst`](@ref) for vertex basis [^wallerberger2021], [`TauConst`](@ref), [`TauLinear`](@ref) for multi-point [^shinaoka2018]
See also: [`MatsubaraConst`](@ref) for vertex basis [^wallerberger2021],
[`TauConst`](@ref),
[`TauLinear`](@ref) for multi-point [^shinaoka2018]
[^wallerberger2021]: https://doi.org/10.1103/PhysRevResearch.3.033168
[^shinaoka2018]: https://doi.org/10.1103/PhysRevB.97.205111
"""
struct AugmentedBasis{B<:AbstractBasis} <: AbstractBasis
basis::B
augmentations::Any
u::Any
uhat::Any
struct AugmentedBasis{S<:Statistics,B<:FiniteTempBasis{S},A<:AugmentationTuple,F,FHAT} <: AbstractBasis{S}
basis :: B
augmentations :: A
u :: F
uhat :: FHAT
end

function AugmentedBasis(basis::AbstractBasis, augmentations...)
augmentations = Tuple(augmentation_factory(basis, augmentations...))
augmentations = Tuple(create(aug, basis) for aug in augmentations)
u = AugmentedTauFunction(basis.u, augmentations)
= AugmentedMatsubaraFunction(basis.uhat, [n -> hat(aug, n) for aug in augmentations])
= AugmentedMatsubaraFunction(basis.uhat, augmentations)
return AugmentedBasis(basis, augmentations, u, û)
end

statistics(basis::AugmentedBasis) = statistics(basis.basis)
naug(basis::AugmentedBasis) = length(basis.augmentations)

function getindex(basis::AugmentedBasis, index)
stop = range_to_size(index)
function getindex(basis::AugmentedBasis, index::AbstractRange)
stop = range_to_length(index)
stop > naug(basis) || error("Cannot truncate to only augmentation.")
return AugmentedBasis(basis.basis[begin:(stop - naug(basis))], basis.augmentations)
end

Base.size(basis::AugmentedBasis) = (length(basis), )
Base.size(basis::AugmentedBasis) = (length(basis),)
Base.length(basis::AugmentedBasis) = naug(basis) + length(basis.basis)
significance(basis::AugmentedBasis) = significance(basis.basis)
accuracy(basis::AugmentedBasis) = accuracy(basis.basis)
Expand All @@ -77,51 +94,50 @@ function iswellconditioned(basis::AugmentedBasis)
return wbasis && waug
end


############################################################################################
# Augmented Functions #
############################################################################################

abstract type AbstractAugmentedFunction <: Function end

struct AugmentedFunction <: AbstractAugmentedFunction
fbasis::Any
faug::Any
struct AugmentedFunction{FB,FA} <: AbstractAugmentedFunction
fbasis :: FB
faug :: FA
end

augmentedfunction(a::AugmentedFunction) = a
augmentedfunction(a::AugmentedFunction) = a

fbasis(a::AbstractAugmentedFunction) = augmentedfunction(a).fbasis
faug(a::AbstractAugmentedFunction) = augmentedfunction(a).faug
naug(a::AbstractAugmentedFunction) = length(faug(a))

Base.length(a::AbstractAugmentedFunction) = naug(a) + length(fbasis(a))
Base.size(a::AbstractAugmentedFunction) = (length(a), )
Base.size(a::AbstractAugmentedFunction) = (length(a),)

function (a::AbstractAugmentedFunction)(x)
fbasis_x = fbasis(a)(x)
faug_x = [faug_l(x) for faug_l in faug(a)]
return fbasis_x .+ faug_x
return vcat(faug_x, fbasis_x)
end
function (a::AbstractAugmentedFunction)(x::AbstractArray)
fbasis_x = fbasis(a)(x)
faug_x = (faug_l.(x) for faug_l in faug(a))
return sum(fbasis_x .+ transpose(faug_xi) for faug_xi in faug_x)
faug_x = [faug_l.(transpose(x)) for faug_l in faug(a)]
return vcat(faug_x..., fbasis_x)
end

function Base.getindex(a::AbstractAugmentedFunction, r::AbstractRange)
stop = range_to_size(r)
stop = range_to_length(r)
stop > naug(a) || error("Don't truncate to only augmentation")
return AugmentedFunction(fbasis(a)[begin:(stop-naug(a))], faug(a))
return AugmentedFunction(fbasis(a)[begin:(stop - naug(a))], faug(a))
end
function Base.getindex(a::AbstractAugmentedFunction, l::Integer)
if l < naug(a)
return faug(a)[l]
else
return fbasis(a)[l-naug(a)]
end
return l naug(a) ? faug(a)[l] : fbasis(a)[l - naug(a)]
end


struct AugmentedTauFunction <: AbstractAugmentedFunction
a::AugmentedFunction
### AugmentedTauFunction

struct AugmentedTauFunction{FB,FA} <: AbstractAugmentedFunction
a::AugmentedFunction{FB,FA}
end

augmentedfunction(aτ::AugmentedTauFunction) =.a
Expand All @@ -137,19 +153,25 @@ function deriv(aτ::AugmentedTauFunction, n=1)
return AugmentedTauFunction(dbasis, daug)
end

### AugmentedMatsubaraFunction

struct AugmentedMatsubaraFunction <: AbstractAugmentedFunction
a::AugmentedFunction
struct AugmentedMatsubaraFunction{FB,FA} <: AbstractAugmentedFunction
a::AugmentedFunction{FB,FA}
end

augmentedfunction(amat::AugmentedMatsubaraFunction) = amat.a

AugmentedMatsubaraFunction(fbasis, faug) = AugmentedMatsubaraFunction(AugmentedFunction(fbasis, faug))
function AugmentedMatsubaraFunction(fbasis, faug)
AugmentedMatsubaraFunction(AugmentedFunction(fbasis, faug))
end

zeta(amat::AugmentedMatsubaraFunction) = zeta(fbasis(amat))

############################################################################################
# Augmentations #
############################################################################################

abstract type AbstractAugmentation end
### TauConst

struct TauConst <: AbstractAugmentation
β::Float64
Expand All @@ -159,60 +181,56 @@ struct TauConst <: AbstractAugmentation
end
end

function create(::Type{TauConst}, basis::AbstractBasis)
statistics(basis)::Bosonic
return TauConst(β(basis))
end
create(::Type{TauConst}, basis::AbstractBasis{Bosonic}) = TauConst(β(basis))

function (aug::TauConst)(τ)
0 τ aug.β || throw(DomainError(τ, "τ must be in [0, β]."))
return 1 / (aug.β)
end
function (aug::TauConst)(n::BosonicFreq)
iszero(n) || return 0.0
return (aug.β)
end
(::TauConst)(::FermionicFreq) = error("TauConst is not a Fermionic basis.")

function deriv(aug::TauConst, n=1)
iszero(n) && return aug
!iszero(n) || return aug
return τ -> 0.0
end

function hat(aug::TauConst, n::BosonicFreq)
iszero(n) || return 0.0
return (aug.β)
end

### TauLinear

struct TauLinear <: AbstractAugmentation
β::Float64
norm::Float64
β :: Float64
norm :: Float64
function TauLinear(β)
β > 0 || throw(DomainError(β, "Temperature must be positive."))
norm = (3 / β)
return new(β, norm)
end
end

function create(::Type{TauLinear}, basis::AbstractBasis)
statistics(basis)::Bosonic
return TauLinear(β(basis))
end
create(::Type{TauLinear}, basis::AbstractBasis{Bosonic}) = TauLinear(β(basis))

function (aug::TauLinear)(τ)
0 τ aug.β || throw(DomainError(τ, "τ must be in [0, β]."))
x = 2 / aug.β * τ - 1
return aug.norm * x
end
function (aug::TauLinear)(n::BosonicFreq)
inv_w = value(n, aug.β)
inv_w = iszero(n) ? inv_w : 1 / inv_w
return aug.norm * 2 / im * inv_w
end
(::TauLinear)(::FermionicFreq) = error("TauLinear is not a Fermionic basis.")

function deriv(aug::TauLinear, n=1)
iszero(n) && return aug
isone(n) && return τ -> aug.norm * 2 / aug.β
return τ -> 0.0
end

function hat(aug::TauLinear, n::BosonicFreq)
inv_w = value(n, aug.β)
inv_w = iszero(n) ? inv_w : 1 / inv_w
return aug.norm * 2/im * inv_w
end

### MatsubaraConst

struct MatsubaraConst <: AbstractAugmentation
β::Float64
Expand All @@ -228,21 +246,8 @@ function (aug::MatsubaraConst)(τ)
0 τ aug.β || throw(DomainError(τ, "τ must be in [0, β]."))
return NaN
end

deriv(aug::MatsubaraConst, n=1) = aug

function hat(::MatsubaraConst, ::MatsubaraFreq)
function (::MatsubaraConst)(::MatsubaraFreq)
return 1.0
end


augmentation_factory(basis::AbstractBasis, augs...) =
Iterators.map(augs) do aug
if aug isa AbstractAugmentation
return aug
else
return create(aug, basis)
end
end

create(aug, )
deriv(aug::MatsubaraConst, _=1) = aug
Loading

0 comments on commit 0f25851

Please sign in to comment.