Skip to content

Commit

Permalink
hip looper functions
Browse files Browse the repository at this point in the history
  • Loading branch information
rcoreilly committed Oct 16, 2024
1 parent f1b85eb commit 4224e36
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 29 deletions.
50 changes: 25 additions & 25 deletions leabra/enumgen.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

63 changes: 60 additions & 3 deletions leabra/hip.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,12 @@
package leabra

import (
"fmt"

"cogentcore.org/core/base/errors"
"cogentcore.org/core/math32"
"github.com/emer/emergent/v2/etime"
"github.com/emer/emergent/v2/looper"
)

// Contrastive Hebbian Learning (CHL) parameters
Expand Down Expand Up @@ -152,9 +157,6 @@ func (pt *Path) EcCa1Defaults() {
// DWt computes the weight change (learning) -- on sending pathways
// Delta version
func (pt *Path) DWtEcCa1() {
if !pt.Learn.Learn {
return
}
slay := pt.Send
rlay := pt.Recv
for si := range slay.Neurons {
Expand Down Expand Up @@ -202,3 +204,58 @@ func (pt *Path) DWtEcCa1() {
}
}
}

// ConfigLoopsHip configures the hippocampal looper and should be included in ConfigLoops
// in model to make sure hip loops is configured correctly.
// see hip.go for an instance of implementation of this function.
func (net *Network) ConfigLoopsHip(ctx *Context, man *looper.Manager) {
var tmpValues []float32
ecout := net.LayerByName("ECout")
ecin := net.LayerByName("ECin")
ca1 := net.LayerByName("CA1")
ca3 := net.LayerByName("CA3")
ca1FromECin := errors.Log1(ca1.RecvPathBySendName("ECin")).(*Path)
ca1FromCa3 := errors.Log1(ca1.RecvPathBySendName("CA3")).(*Path)
ca3FromDg := errors.Log1(ca3.RecvPathBySendName("DG")).(*Path)

dgPjScale := ca3FromDg.WtScale.Rel

// configure events -- note that events are shared between Train, Test
// so only need to do it once on Train
mode := etime.Train
stack := man.Stacks[mode]
cyc, _ := stack.Loops[etime.Cycle]
minusStart := cyc.EventByName("MinusPhase") // cycle 0
minusStart.OnEvent.Add("HipMinusPhase:Start", func() {
ca1FromECin.WtScale.Abs = 1
ca1FromCa3.WtScale.Abs = 0
ca3FromDg.WtScale.Rel = 0
net.GScaleFromAvgAct()
net.InitGInc()
})
quarter1 := cyc.EventByName("Quarter1")
quarter1.OnEvent.Add("Hip:Quarter1", func() {
ca1FromECin.WtScale.Abs = 0
ca1FromCa3.WtScale.Abs = 1
if ctx.Mode == etime.Test {
ca3FromDg.WtScale.Rel = 1 // weaker
fmt.Println("test:, rel = 1")
} else {
ca3FromDg.WtScale.Rel = dgPjScale
fmt.Println("train, rel:", dgPjScale)
}
net.GScaleFromAvgAct()
net.InitGInc()
})
plus := cyc.EventByName("PlusPhase")
plus.OnEvent.InsertBefore("MinusPhase:End", "HipPlusPhase:Start", func() {
ca1FromECin.WtScale.Abs = 1
ca1FromCa3.WtScale.Abs = 0
if ctx.Mode == etime.Train {
ecin.UnitValues(&tmpValues, "Act", 0)
ecout.ApplyExt1D32(tmpValues)
}
net.GScaleFromAvgAct()
net.InitGInc()
})
}
Loading

0 comments on commit 4224e36

Please sign in to comment.