-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathpool.go
240 lines (194 loc) · 7.52 KB
/
pool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
// Copyright (c) 2019, The Emergent Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package axon
import (
"cogentcore.org/core/math32"
"cogentcore.org/core/vgpu/gosl/slbool"
"github.com/emer/axon/v2/fsfffb"
)
//gosl:hlsl pool
// #include "avgmaxi.hlsl"
//gosl:end pool
//gosl:start pool
// AvgMaxPhases holds average and maximum statistics over the neurons of a
// Pool, captured at several time scales within a standard ThetaCycle.
// On the GPU it is much cheaper to gather everything once per cycle and
// derive the other time scales as snapshots of the Cycle values.
// Cycle-level values are refreshed at the *start* of each cycle from the
// prior cycle's neuron values, so they generally lag by one cycle.
type AvgMaxPhases struct {

	// Cycle is refreshed every cycle and is the source for all of the
	// other time scales.
	Cycle AvgMaxI32 `display:"inline"`

	// Minus is the Cycle snapshot taken at the end of the minus phase.
	Minus AvgMaxI32 `display:"inline"`

	// Plus is the Cycle snapshot taken at the end of the plus phase.
	Plus AvgMaxI32 `display:"inline"`

	// Prev is the Plus snapshot from the previous plus phase.
	Prev AvgMaxI32 `display:"inline"`
}

// CycleToMinus stores the current Cycle values as the Minus snapshot.
// It also copies the existing Plus snapshot into Prev, so that Prev keeps
// the previous plus-phase values after the next CycleToPlus overwrites Plus.
func (am *AvgMaxPhases) CycleToMinus() {
	am.Prev = am.Plus
	am.Minus = am.Cycle
}

// CycleToPlus stores the current Cycle values as the Plus snapshot.
func (am *AvgMaxPhases) CycleToPlus() {
	am.Plus = am.Cycle
}

// Calc runs Calc (passing refIndex through) on the Cycle values only,
// leaving them ready for the next round of aggregation. The phase
// snapshots (Minus, Plus, Prev) are not touched.
func (am *AvgMaxPhases) Calc(refIndex int32) {
	am.Cycle.Calc(refIndex)
}

// Zero fully resets every time scale; intended for InitActs.
func (am *AvgMaxPhases) Zero() {
	am.Prev.Zero()
	am.Plus.Zero()
	am.Minus.Zero()
	am.Cycle.Zero()
}
// PoolAvgMax collects, for each neuron variable of interest, the average and
// maximum over the neurons of a Pool at Cycle, Minus and Plus phase time scales.
// All of the cycle-level values are refreshed at the *start* of each cycle
// from the prior cycle's values, so they generally lag by one cycle.
//
// NOTE: this code is also translated to HLSL by gosl, so each field is
// handled explicitly rather than through loops over pointers.
type PoolAvgMax struct {

	// CaSpkP avg and max: continuously updated with a roughly 40 msec
	// integration window, ending up capturing potentiation (a plus-phase
	// signal). This is the primary variable for tracking overall pool activity.
	CaSpkP AvgMaxPhases `edit:"-" display:"inline"`

	// CaSpkD avg and max: the longer-term depression / DAPK1 signal in the layer.
	CaSpkD AvgMaxPhases `edit:"-" display:"inline"`

	// SpkMax avg and max (SpkMax is based on CaSpkP): reflects peak activity
	// at any point across the cycle.
	SpkMax AvgMaxPhases `edit:"-" display:"inline"`

	// Act avg and max: the firing-rate value.
	Act AvgMaxPhases `edit:"-" display:"inline"`

	// GeInt avg and max: integrated running-average excitatory conductance.
	GeInt AvgMaxPhases `edit:"-" display:"inline"`

	// GiInt avg and max: integrated running-average inhibitory conductance.
	GiInt AvgMaxPhases `edit:"-" display:"inline"`
}

// SetN assigns n as the aggregation count N of the Cycle-level
// statistics of every variable.
func (am *PoolAvgMax) SetN(n int32) {
	am.GiInt.Cycle.N = n
	am.GeInt.Cycle.N = n
	am.Act.Cycle.N = n
	am.SpkMax.Cycle.N = n
	am.CaSpkD.Cycle.N = n
	am.CaSpkP.Cycle.N = n
}

// CycleToMinus snapshots the current Cycle values of every variable
// into its Minus phase values (see AvgMaxPhases.CycleToMinus).
func (am *PoolAvgMax) CycleToMinus() {
	am.GiInt.CycleToMinus()
	am.GeInt.CycleToMinus()
	am.Act.CycleToMinus()
	am.SpkMax.CycleToMinus()
	am.CaSpkD.CycleToMinus()
	am.CaSpkP.CycleToMinus()
}

// CycleToPlus snapshots the current Cycle values of every variable
// into its Plus phase values.
func (am *PoolAvgMax) CycleToPlus() {
	am.GiInt.CycleToPlus()
	am.GeInt.CycleToPlus()
	am.Act.CycleToPlus()
	am.SpkMax.CycleToPlus()
	am.CaSpkD.CycleToPlus()
	am.CaSpkP.CycleToPlus()
}

// Init calls Init on the Cycle-level values of every variable, preparing
// them for a new round of updates. The Cycle values are normally always
// left initialized, so calling this is generally unnecessary.
func (am *PoolAvgMax) Init() {
	am.GiInt.Cycle.Init()
	am.GeInt.Cycle.Init()
	am.Act.Cycle.Init()
	am.SpkMax.Cycle.Init()
	am.CaSpkD.Cycle.Init()
	am.CaSpkP.Cycle.Init()
}

// Zero fully resets every time scale of every variable; intended for InitActs.
func (am *PoolAvgMax) Zero() {
	am.GiInt.Zero()
	am.GeInt.Zero()
	am.Act.Zero()
	am.SpkMax.Zero()
	am.CaSpkD.Zero()
	am.CaSpkP.Zero()
}

// Calc finalizes the Cycle-level aggregates of every variable via
// AvgMaxPhases.Calc, which also leaves them ready for the next round
// of aggregation.
func (am *PoolAvgMax) Calc(refIndex int32) {
	am.GiInt.Calc(refIndex)
	am.GeInt.Calc(refIndex)
	am.Act.Calc(refIndex)
	am.SpkMax.Calc(refIndex)
	am.CaSpkD.Calc(refIndex)
	am.CaSpkP.Calc(refIndex)
}
//gosl:end pool
// note: the following is actually being used despite appearing to be
// commented out! it is auto-uncommented when copied to hlsl
// MUST be kept in sync with the Go Pool.AvgMaxUpdate method below,
// including its use of the absolute value of Act (which can be negative).
//gosl:hlsl pool
/*
// // AtomicUpdatePoolAvgMax provides an atomic update using atomic ints
// // implemented by InterlockedAdd HLSL intrinsic.
// // This is a #define because it doesn't work on arg values --
// // must be directly operating on a RWStorageBuffer entity.
// // Act is aggregated as abs(Act) to match the CPU (Go) path.
#define AtomicUpdatePoolAvgMax(am, ctx, ni, di) \
	AtomicUpdateAvgMaxI32(am.CaSpkP.Cycle, NrnV(ctx, ni, di, CaSpkP)); \
	AtomicUpdateAvgMaxI32(am.CaSpkD.Cycle, NrnV(ctx, ni, di, CaSpkD)); \
	AtomicUpdateAvgMaxI32(am.SpkMax.Cycle, NrnV(ctx, ni, di, SpkMax)); \
	AtomicUpdateAvgMaxI32(am.Act.Cycle, abs(NrnV(ctx, ni, di, Act))); \
	AtomicUpdateAvgMaxI32(am.GeInt.Cycle, NrnV(ctx, ni, di, GeInt)); \
	AtomicUpdateAvgMaxI32(am.GiInt.Cycle, NrnV(ctx, ni, di, GiInt))
*/
//gosl:end pool
//gosl:start pool
// Pool contains computed values for FS-FFFB inhibition,
// and various other state values for layers
// and pools (unit groups) that can be subject to inhibition.
// NOTE(review): this struct lives in a gosl region, so its field order and
// types presumably must match the generated HLSL struct -- confirm before
// adding, removing or reordering fields.
type Pool struct {

	// starting and ending (exclusive) layer-wise indexes for the list of neurons
	// in this pool; EdIndex - StIndex is the neuron count (see NNeurons)
	StIndex, EdIndex uint32 `edit:"-"`

	// layer index in global layer list
	LayIndex uint32 `display:"-"`

	// data parallel index (innermost index per layer)
	DataIndex uint32 `display:"-"`

	// pool index in global pool list
	PoolIndex uint32 `display:"-"`

	// is this a layer-wide pool? if not, it represents a sub-pool of units within a 4D layer
	IsLayPool slbool.Bool `edit:"-"`

	// for special types where relevant (e.g., MatrixLayer, BGThalLayer), indicates if the pool was gated
	Gated slbool.Bool `edit:"-"`

	// unexported padding field; NOTE(review): presumably keeps the memory
	// layout aligned for the GPU copy of this struct -- confirm
	pad uint32

	// fast-slow FFFB inhibition values
	Inhib fsfffb.Inhib `edit:"-"`

	// average and max values for relevant variables in this pool, at different time scales
	AvgMax PoolAvgMax

	// absolute value of AvgDif differences from actual neuron ActPct relative to TrgAvg
	AvgDif AvgMaxI32 `edit:"-" display:"inline"`
}
// Init resets the pool's inhibition, gating and activity statistics; it is
// called during InitActs. The aggregation count N of the Cycle-level AvgMax
// statistics and of AvgDif is set to the number of neurons in the pool.
func (pl *Pool) Init() {
	nn := int32(pl.NNeurons())
	pl.Gated.SetBool(false)
	pl.Inhib.Init()

	// Zero is a full reset, so SetN must come after it.
	pl.AvgMax.Zero()
	pl.AvgMax.SetN(nn)

	// N is assigned before Init, matching the established order.
	pl.AvgDif.N = nn
	pl.AvgDif.Init()
}
// NNeurons reports how many neurons belong to the pool, i.e. the size of
// the half-open layer-wise index range [StIndex, EdIndex).
func (pl *Pool) NNeurons() int {
	span := pl.EdIndex - pl.StIndex
	return int(span)
}
//gosl:end pool
// AvgMaxUpdate folds the current values of neuron ni (data-parallel index di)
// into the Cycle-level aggregates of pl.AvgMax. This is the CPU counterpart
// of the AtomicUpdatePoolAvgMax HLSL macro, and the two must stay in sync.
func (pl *Pool) AvgMaxUpdate(ctx *Context, ni, di uint32) {
	am := &pl.AvgMax
	am.CaSpkP.Cycle.UpdateValue(NrnV(ctx, ni, di, CaSpkP))
	am.CaSpkD.Cycle.UpdateValue(NrnV(ctx, ni, di, CaSpkD))
	am.SpkMax.Cycle.UpdateValue(NrnV(ctx, ni, di, SpkMax))
	// Act can be negative, so its magnitude is what gets aggregated.
	am.Act.Cycle.UpdateValue(math32.Abs(NrnV(ctx, ni, di, Act)))
	am.GeInt.Cycle.UpdateValue(NrnV(ctx, ni, di, GeInt))
	am.GiInt.Cycle.UpdateValue(NrnV(ctx, ni, di, GiInt))
}
// TestValues records this pool's Cycle-level CaSpkD average into vals under
// the key layKey + " CaSpkD Avg", giving an integrated summary of pool
// activity for testing. It writes into the caller's map (which must be
// non-nil) rather than returning a new one.
func (pl *Pool) TestValues(layKey string, vals map[string]float32) {
	key := layKey + " CaSpkD Avg"
	vals[key] = pl.AvgMax.CaSpkD.Cycle.Avg
}