Skip to content

Commit

Permalink
faster version
Browse files Browse the repository at this point in the history
  • Loading branch information
srush committed Aug 10, 2024
1 parent 08599a6 commit c00c4f2
Show file tree
Hide file tree
Showing 26 changed files with 125 additions and 136 deletions.
8 changes: 4 additions & 4 deletions Makefile
Original file line number Diff line number Diff line change
Expand Up @@ -287,12 +287,12 @@ images: $(IMAGES)

# Rule to process each image
$(IMAGES):
python $(EXAMPLES_DIR)/$@.py
jupytext --execute --run-path . --set-kernel chalk --to ipynb -o $(EXAMPLES_DIR)/output/$@.ipynb $(EXAMPLES_DIR)/$@.py
jupyter nbconvert --to html $(EXAMPLES_DIR)/output/$@.ipynb
CHALK_CHECK=1 python $(EXAMPLES_DIR)/$@.py
# jupytext --execute --run-path . --set-kernel chalk --to ipynb -o $(EXAMPLES_DIR)/output/[email protected] $(EXAMPLES_DIR)/[email protected]
# jupyter nbconvert --to html $(EXAMPLES_DIR)/output/[email protected]

# List of images to be generated
VISTESTS := alignment arc broadcast combinators envelope names path rendering shapes styles subdiagram trails transformations text
VISTESTS := alignment arc broadcast combinators envelope names path rendering shapes styles subdiagram trails transformations trace text

VT_DIR := api

Expand Down
4 changes: 3 additions & 1 deletion api/arc.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
from chalk import *

print(arc(1, 0, 90))
arc(1, 0, 270)



square(2).fill_color("orange") + arc(1, 0, 90)
3 changes: 1 addition & 2 deletions chalk/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@
from chalk.arrowheads import *
from chalk.export import *

if eval(os.environ.get("CHALK_CHECK", "1")):
if eval(os.environ.get("CHALK_CHECK", "0")):
assert hook is not None
hook.uninstall()

Expand All @@ -49,7 +49,6 @@
chalk.core.ApplyStyle,
chalk.core.ComposeAxis,
chalk.envelope.EnvDistance,
chalk.trace.TraceDistances,
chalk.style.StyleHolder,
chalk.trail.Trail,
chalk.path.Path,
Expand Down
7 changes: 5 additions & 2 deletions chalk/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -112,11 +112,14 @@ def compose(
if isinstance(self, Empty):
return other
elif isinstance(self, Compose) and isinstance(other, Compose):
return Compose(envelope, self.diagrams + other.diagrams)
if self.envelope is None and other.envelope is None:
return Compose(envelope, self.diagrams + other.diagrams)
else:
return Compose(envelope, (self, other))
elif isinstance(other, Empty) and not isinstance(self, Compose):
return Compose(envelope, (self,))

elif isinstance(other, Compose):
elif isinstance(other, Compose) and other.envelope is None:
return Compose(envelope, (self,) + other.diagrams)
else:
return Compose(envelope, (self, other))
Expand Down
49 changes: 21 additions & 28 deletions chalk/envelope.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,30 +64,31 @@ def post_transform(
# Translation
diff = tx.dot(tx.scale_vec(u, 1 / tx.dot(v, v)), v)
return tx.np.asarray(after_linear - diff)
# return tx.np.max(out, axis=-1)


@tx.jit
def env(transform: tx.Affine, angles: tx.Angles, d: tx.V2_tC) -> tx.Array:
# Push the user batch dimensions to the left.
batch_shape = d.shape[:-2]
segments_shape = transform.shape[:-2]
return_shape = batch_shape + segments_shape[:-1]
if segments_shape[-1] == 0:
return tx.np.zeros(return_shape)
for _ in range(len(segments_shape)):
d = d[..., None, :, :]

pre = pre_transform(transform, d)
trans = arc_envelope(transform, angles, pre[0])
v= post_transform(pre[1], pre[2], pre[3], trans).max(-1) # type: ignore
assert v.shape == return_shape, f"{v.shape} {return_shape}"
return tx.np.asarray(v)

@dataclass
class Envelope(Transformable, Monoid, Batchable):
segment: BatchSegment

def __call__(self, direction: V2_t) -> Scalars:
def run(d): # type: ignore
if self.segment.angles.shape[0] == 0:
return tx.np.array(0.0)

@partial(tx.np.vectorize, signature="(a,3,3),(a,2)->()")
def env(t, ang): # type: ignore
v = Envelope.general_transform(
t, lambda x: arc_envelope(t, ang, x), d
).max()
return v

return env(self.segment.transform,
self.segment.angles)

run = tx.multi_vmap(run, len(direction.shape) - 2) # type: ignore
return run(direction) # type: ignore
def __call__(self, direction: V2_t) -> tx.Array:
return env(*self.segment.tuple(), direction)

def __add__(self: BatchEnvelope, other: BatchEnvelope) -> BatchEnvelope:
return Envelope(self.segment + other.segment)
Expand All @@ -98,9 +99,7 @@ def __add__(self: BatchEnvelope, other: BatchEnvelope) -> BatchEnvelope:

@property
def center(self) -> P2_t:
d = [
self(Envelope.all_dir[d]) for d in range(Envelope.all_dir.shape[0])
]
d = self(Envelope.all_dir)
return P2(
(-d[1] + d[0]) / 2,
(-d[3] + d[2]) / 2,
Expand Down Expand Up @@ -149,12 +148,6 @@ def to_segments(self, angle: int = 45) -> V2_t:
v = tx.polar(tx.np.arange(0, 361, angle) * 1.0)
return tx.scale_vec(v, self(v))

@staticmethod
def general_transform(
t: Affine, fn: Callable[[V2_t], Scalars], d: V2_t
) -> tx.ScalarsC:
pre = pre_transform(t, d)
return post_transform(pre[1], pre[2], pre[3], fn(pre[0])) # type: ignore

def apply_transform(self, t: Affine) -> Envelope:
return Envelope(self.segment.apply_transform(t[..., None, :, :]))
Expand Down Expand Up @@ -189,7 +182,7 @@ def visit_apply_transform(
"Defaults to pass over"
return diagram.diagram.accept(self, t @ diagram.transform)


@tx.jit
def get_envelope(self: Diagram, t: Optional[Affine] = None) -> Envelope:
# assert self.size() == ()
if t is None:
Expand Down
4 changes: 2 additions & 2 deletions chalk/layout.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
from typing import TYPE_CHECKING, Any, List, Optional, Tuple

import chalk.transform as tx
from chalk.backend.patch import Patch
from chalk.backend.patch import Patch, patch_from_prim
from chalk.monoid import Monoid
from chalk.style import StyleHolder
from chalk.transform import Affine
Expand Down Expand Up @@ -59,7 +59,7 @@ def layout(
s = s.apply_style(style)
if draw_height is None:
draw_height = height
patches = [Patch.from_prim(prim, style, draw_height) for prim in get_primitives(s)]
patches = [patch_from_prim(prim, style, draw_height) for prim in get_primitives(s)]
return patches, height, width


Expand Down
2 changes: 1 addition & 1 deletion chalk/style.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ class WidthType(Enum):
LOCAL = auto()
NORMALIZED = auto()


@tx.jit
def Style(
line_width_: Optional[PropLike] = None,
line_color_: Optional[ColorLike] = None,
Expand Down
115 changes: 57 additions & 58 deletions chalk/trace.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,35 +14,35 @@
from chalk.types import Diagram


@dataclass
class TraceDistances(Monoid):
distance: tx.Scalars
mask: tx.Mask
# @dataclass
# class TraceDistances(Monoid):
# distance: tx.Scalars
# mask: tx.Mask

def __iter__(self): # type: ignore
yield self.distance
yield self.mask
# def __iter__(self): # type: ignore
# yield self.distance
# yield self.mask

def tuple(self) -> Tuple[tx.Scalars, tx.Mask]:
return self.distance, self.mask
# def tuple(self) -> Tuple[tx.Scalars, tx.Mask]:
# return self.distance, self.mask

def __getitem__(self, i: int): # type: ignore
if i == 0:
return self.distance
if i == 1:
return self.mask
# def __getitem__(self, i: int): # type: ignore
# if i == 0:
# return self.distance
# if i == 1:
# return self.mask

def __add__(self, other: TraceDistances) -> TraceDistances: # type: ignore
return TraceDistances(*tx.union(self.tuple(), other.tuple()))
# def __add__(self, other: TraceDistances) -> TraceDistances: # type: ignore
# return TraceDistances(*tx.union(self.tuple(), other.tuple()))

@staticmethod
def empty() -> TraceDistances:
return TraceDistances(tx.np.asarray([]), tx.np.asarray([]))
# @staticmethod
# def empty() -> TraceDistances:
# return TraceDistances(tx.np.asarray([]), tx.np.asarray([]))

def reduce(self, axis: int = 0) -> TraceDistances:
return TraceDistances(
*tx.union_axis((self.distance, self.mask), axis=axis)
)
# def reduce(self, axis: int = 0) -> TraceDistances:
# return TraceDistances(
# *tx.union_axis((self.distance, self.mask), axis=axis)
# )


@dataclass
Expand All @@ -58,21 +58,25 @@ class Trace(Monoid, Transformable):
segment: Segment

def __call__(self, point: P2_t, direction: V2_t) -> TraceDistances:
if len(point.shape) == 2:
point = point.reshape(1, 3, 1)
if len(direction.shape) == 2:
direction = direction.reshape(1, 3, 1)
assert point[..., -1, 0] == 1.0, point
point, direction = tx.np.broadcast_arrays(point, direction)

# Push the __call__ batch dimensions to the left.
batch_shape = point.shape[:-2]
print(self.segment.transform.shape)
segments_shape = self.segment.transform.shape[:-2]
for _ in range(len(segments_shape)):
point = point[..., None, :, :]
direction = direction[..., None, :, :]

d, m = Trace.general_transform(
self.segment.transform,
lambda x: arc_trace(self.segment, x),
lambda x1, x2: arc_trace(self.segment.transform, self.segment.angles, x1, x2),
Ray(point, direction),
)

ad = tx.np.argsort(d + (1 - m) * 1e10, axis=1)
d = tx.np.take_along_axis(d, ad, axis=1)
m = tx.np.take_along_axis(m, ad, axis=1)
return TraceDistances(d, m)
ad = tx.np.argsort(d + (1 - m) * 1e10, axis=-1)
d = tx.np.take_along_axis(d, ad, axis=-1)
m = tx.np.take_along_axis(m, ad, axis=-1)
return (d, m)

@staticmethod
def general_transform(
Expand All @@ -82,16 +86,11 @@ def general_transform(

def wrapped(
ray: Ray,
) -> TraceDistances:
td = TraceDistances(
*fn(
Ray(
t1 @ ray.pt[..., None, :, :],
t1 @ ray.v[..., None, :, :],
)
)
)
return td.reduce(axis=-1)
) :
d, m = fn(t1 @ ray.pt, t1 @ ray.v)
d = d.reshape(d.shape[:-2] + (-1,))
m = m.reshape(m.shape[:-2] + (-1,))
return d, m

return wrapped(r)

Expand All @@ -106,29 +105,29 @@ def trace_v(self, p: P2_t, v: V2_t) -> TraceDistances:
ad = tx.np.argsort(dists + (1 - m) * 1e10, axis=1)
m = tx.np.take_along_axis(m, ad, axis=1)
s = d[:, 0]
return TraceDistances(s[..., None] * v, m[:, 0])
return (s[..., None] * v, m[:, 0])

def trace_p(self, p: P2_t, v: V2_t) -> TraceDistances:
u, m = self.trace_v(p, v)
return TraceDistances(p + u, m)
return (p + u, m)

def max_trace_v(self, p: P2_t, v: V2_t) -> TraceDistances:
return self.trace_v(p, -v)

def max_trace_p(self, p: P2_t, v: V2_t) -> TraceDistances:
u, m = self.max_trace_v(p, v)
return TraceDistances(p + u, m)

@staticmethod
def combine(p1: TraceDistances, p2: TraceDistances) -> TraceDistances:
ps, m = p1
ps2, m2 = p2
ps = tx.np.concatenate([ps, ps2], axis=1)
m = tx.np.concatenate([m, m2], axis=1)
ad = tx.np.argsort(ps + (1 - m) * 1e10, axis=1)
ps = tx.np.take_along_axis(ps, ad, axis=1)
m = tx.np.take_along_axis(m, ad, axis=1)
return TraceDistances(ps, m)
return (p + u, m)

# @staticmethod
# def combine(p1: TraceDistances, p2: TraceDistances) -> TraceDistances:
# ps, m = p1
# ps2, m2 = p2
# ps = tx.np.concatenate([ps, ps2], axis=1)
# m = tx.np.concatenate([m, m2], axis=1)
# ad = tx.np.argsort(ps + (1 - m) * 1e10, axis=1)
# ps = tx.np.take_along_axis(ps, ad, axis=1)
# m = tx.np.take_along_axis(m, ad, axis=1)
# return TraceDistances(ps, m)


class GetLocatedSegments(DiagramVisitor[Segment, Affine]):
Expand Down
4 changes: 2 additions & 2 deletions chalk/trail.py
Original file line number Diff line number Diff line change
Expand Up @@ -140,7 +140,7 @@ def hrule(length: Floating) -> Trail:
def vrule(length: Floating) -> Trail:
return seg(length * tx.unit_y)


@tx.jit
def square() -> Trail:
t = seg(tx.unit_x) + seg(tx.unit_y)
return (t + t.rotate_by(0.5)).close()
Expand All @@ -160,7 +160,7 @@ def rounded_rectangle(
) + seg(0.01 * tx.unit_y)
return trail.close()


@tx.jit
def circle(clockwise: bool = True) -> Trail:
sides = 4
dangle = -90
Expand Down
Loading

0 comments on commit c00c4f2

Please sign in to comment.