Skip to content

Commit

Permalink
experimental/graph: add gauge & multi value support
Browse files Browse the repository at this point in the history
  • Loading branch information
jxy committed Jul 17, 2024
1 parent 0703a77 commit b0fdf04
Show file tree
Hide file tree
Showing 6 changed files with 1,134 additions and 86 deletions.
23 changes: 12 additions & 11 deletions src/experimental/graph/core.nim
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,6 @@ type
Gfunc* {.acyclic.} = ref object
## Represent an functional operation: [input] -> output,
forward: proc(z: Gvalue)
arg: Gvalue ## extra argument forward/backward uses, must be immutable and can be shared, use getArg/setArg
backward: proc(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue ## create new graph for backprop
runCount: int
name: string
Expand Down Expand Up @@ -49,12 +48,10 @@ var graphDebug* = false

proc newGfunc*(
forward: proc(z: Gvalue) = nil,
arg: Gvalue = nil,
backward: proc(zb: Gvalue, z: Gvalue, i: int, dep: Gvalue): Gvalue = nil,
name: string): Gfunc =
Gfunc(
forward: forward,
arg: arg,
backward: backward,
name: name)

Expand All @@ -68,24 +65,27 @@ method `$`*(x: Gvalue): string {.base.} =
if f != nil:
result &= " " & $f

proc `$`*(x: Gfunc): string =
if x.arg == nil:
x.name & "<" & $x.runCount & ">"
else:
x.name & "<" & $x.runCount & ", " & $x.arg & ">"
proc `$`*(x: Gfunc): string = x.name & "<" & $x.runCount & ">"

proc nodeRepr*(x: Gvalue): string =
let f = x.gfunc
result = $x & " (" & $x.epoch & " " & $x.tag & ")" & "@0X" & strip(toHex(cast[int](x)), trailing = false, chars = {'0'})
if f != nil:
result &= " " & $f & "@0X" & strip(toHex(cast[int](f)), trailing = false, chars = {'0'})

method copyGvalue*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("copyGvalue(" & $x & ")")
method assignGvalue*(z: Gvalue, x: Gvalue) {.base.} =
method newOneOf*(x: Gvalue): Gvalue {.base.} = raiseErrorBaseMethod("newOneOf(" & $x & ")")
method valCopy*(z: Gvalue, x: Gvalue) {.base.} = raiseErrorBaseMethod("valCopy(" & $z & "," & $x & ")")

proc assignGvalue(z: Gvalue, x: Gvalue) =
z.tag = x.tag
z.inputs = x.inputs
z.gfunc = x.gfunc
z.epoch = x.epoch
z.valCopy x

proc copyGvalue(x: Gvalue): Gvalue =
result = newOneOf x
result.assignGvalue x

let identPlaceholderGFunc = newGfunc(name = "identPlaceholder")
proc identPlaceholder(x: Gvalue): Gvalue =
Expand Down Expand Up @@ -151,7 +151,7 @@ proc updated*(x: Gvalue) =
inc epoch
x.epoch = epoch

proc eval*(v: Gvalue) =
proc eval*(v: Gvalue): Gvalue {.discardable.} =
proc r(x: Gvalue) =
if gtVisited in x.tag:
return
Expand All @@ -176,6 +176,7 @@ proc eval*(v: Gvalue) =
raiseError("inputs.len: " & $x.inputs.len & ", but no forward function defined for:\n" & x.nodeRepr)
v.r
v.tagClearVisited
v

type
Grad = object
Expand Down
Loading

0 comments on commit b0fdf04

Please sign in to comment.