Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] adding an initial stab at an equinox module that wraps ipctk. #101

Open
wants to merge 3 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions .github/workflows/ci-build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -28,17 +28,17 @@ jobs:
sudo apt-get install -y libsuitesparse-dev
python -m pip install --upgrade pip
# pip install .
pip install -e ".[sparse]"
pip install -e ".[ipc,sparse]"
# pip install flake8 pytest
# if [ -f requirements.txt ]; then pip install -r requirements-test.txt; fi
- name: pytest
run: |
pip install -e ".[sparse,test]"
pip install -e ".[ipc,sparse,test]"
python -m pytest -n auto optimism --cov=optimism -Wignore
# we can also add the flag -n auto for parallel testing
- name: docs
run: |
pip install -e ".[docs,sparse,test]"
pip install -e ".[docs,ipc,sparse,test]"
cd docs
sphinx-apidoc -o source/ ../optimism -P
make html
Expand Down
53 changes: 53 additions & 0 deletions examples/ipc/example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
from jax import custom_jvp
import jax
import jax.numpy as jnp


@custom_jvp
def f(x):
return jnp.sin(x)


@f.defjvp
def _f_jvp(primals, tangents):
x, = primals
x_dot, = tangents
primal_out = f(x)
tangent_out = jnp.cos(x) * x_dot
return primal_out[0], tangent_out[0]


def f_jvp(x, v):
return _f_jvp(x, v)


# forward over reverse
@custom_jvp
def f_grad(x):
return jax.grad(f)(x)



@f_grad.defjvp
def _f_hvp(primals, tangents):
x, = primals
x_dot, = tangents
primal_out = f_grad(x)
tangent_out = -jnp.sin(x) * x_dot
return primal_out[0], tangent_out[0]

def f_hvp(x, v):
return jax.jvp(f_grad, x, v)

def f_hess(x):
return jax.hessian(f)((x,))


x = jnp.array([jnp.sqrt(2.)])
v = jnp.ones(1)
print(f(x))
print(f_grad(x))
# print(jax.grad(f)(x))
# print(hvp(x, v))
print(f_hvp((x,), (v,)))
print(f_hess(x))
129 changes: 129 additions & 0 deletions optimism/contact/IPCTK.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
from .. import QuadratureRule
from functools import partial
import equinox as eqx
import ipctk
import jax
import jax.numpy as jnp
import meshio
import numpy as np


class IPCTKContact(eqx.Module):
mesh: any
collision_mesh: any
collisions: any
dof_manager: any
max_neighbors: int
q_rule: QuadratureRule.QuadratureRule
dhat: float
potential: ipctk.BarrierPotential

def __init__(self, mesh_file, dof_manager, max_neighbors=3, q_degree=3, dhat=2e-3):
self.mesh = meshio.read(mesh_file)
self.max_neighbors = max_neighbors
q_rule = QuadratureRule.create_quadrature_rule_1D(q_degree)
q_rule = eqx.tree_at(lambda x: x.xigauss, q_rule, jnp.array(q_rule.xigauss))
self.q_rule = eqx.tree_at(lambda x: x.wgauss, q_rule, jnp.array(q_rule.wgauss))
rest_positions = self.mesh.points
faces = self.mesh.cells_dict['triangle']
edges = ipctk.edges(faces)
self.dhat = dhat
self.collision_mesh = ipctk.CollisionMesh(rest_positions, edges)
# self.collisions = self.update_collisions()
self.collisions = None
self.potential = ipctk.BarrierPotential(self.dhat)
self.dof_manager = dof_manager



def energy(self, U):
return _energy(self, U)

def gradient(self, U):
return _gradient(self, U)

def update_collisions(self, coords):
coords = np.array(coords)
collisions = ipctk.Collisions()
collisions.build(self.collision_mesh, coords, self.dhat) # performs culling to find only potential collisions with distances less than dhat
new_self = eqx.tree_at(lambda x: x.collisions, self, collisions)
return new_self
# return collisions
# def energy(self, Uu, p):
# return _energy(self, Uu, p)

# def gradient(self, Uu, p):
# return _gradient(self, Uu, p)

# def hvp(self, Uu, p, v):
# return jax.jvp(self.gradient, (Uu, p), (v, p))[1]

# annoying below since jvps don't really play nice with classes that well

@partial(jax.custom_jvp, nondiff_argnums=(0,))
def _energy(contact, U):
coords = contact.collision_mesh.rest_positions.copy()[:, 0:2]
curr_coords = U + coords
# collisions = contact.collisions(curr_coords)
collisions = contact.collisions
barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords)
return barrier_energy


@_energy.defjvp
def _jvp(contact, primals, tangents):
U, = primals
dU, = tangents
coords = contact.collision_mesh.rest_positions.copy()[:, 0:2]
curr_coords = U + coords
# collisions = contact.collisions(curr_coords)
collisions = contact.collisions
barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords)
barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords)
return barrier_energy, jnp.dot(barrier_grad, dU.flatten())


# @partial(jax.custom_jvp, nondiff_argnums=(0,))
def _gradient(contact, U):
return jax.grad(contact.energy)(U)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def _energy_old(contact, Uu, p):
U = contact.dof_manager.create_field(Uu, p[0])
coords = contact.collision_mesh.rest_positions.copy()[:, 0:2]
curr_coords = U + coords
collisions = contact.collisions(curr_coords)
barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords)
return barrier_energy


@_energy_old.defjvp
def _jvp_old(contact, primals, tangents):
Uu, p = primals
dUu, dp = tangents
U = contact.dof_manager.create_field(Uu, p[0])
coords = contact.collision_mesh.rest_positions.copy()[:, 0:2]
curr_coords = U + coords
collisions = contact.collisions(curr_coords)
barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords)
barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords)
return barrier_energy, jnp.dot(barrier_grad, dUu)


@partial(jax.custom_jvp, nondiff_argnums=(0,))
def _gradient_old(contact, Uu, p):
return jax.grad(contact.energy, argnums=0)(Uu, p)


@_gradient_old.defjvp
def _hvp_old(contact, primals, tangents):
Uu, p = primals
dUu, dp = tangents
U = contact.dof_manager.create_field(Uu, p[0])
coords = contact.collision_mesh.rest_positions.copy()[:, 0:2]
curr_coords = U + coords
collisions = contact.collisions(curr_coords)
barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords)
barrier_hess = contact.potential.hessian(collisions, contact.collision_mesh, curr_coords)
return barrier_grad, barrier_hess @ dUu
107 changes: 107 additions & 0 deletions optimism/contact/test/test_IPCTK.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,107 @@
import equinox as eqx
import ipctk
import jax
import jax.numpy as jnp
import meshio
import unittest
from optimism import VTKWriter
from optimism.JaxConfig import *
from optimism import FunctionSpace
from optimism import Mesh
from optimism import Objective
from optimism import QuadratureRule
from optimism.FunctionSpace import DofManager
from optimism.test.MeshFixture import MeshFixture
from optimism.contact import Contact, IPCTK
from optimism.contact.IPCTK import IPCTKContact


def write_vtk_mesh(mesh, meshName):
writer = VTKWriter.VTKWriter(mesh, baseFileName=meshName)
writer.write()


class TwoBodyICPTKContactFixture(MeshFixture):

def setUp(self):
self.targetDispGrad = np.array([[0.1, -0.2],[0.4, -0.1]])

m1 = self.create_mesh_and_disp(3, 5, [0.0, 1.0], [0.0, 1.0],
lambda x : self.targetDispGrad.dot(x), '1')

m2 = self.create_mesh_and_disp(2, 4, [1.001, 2.001], [0.0, 1.0],
lambda x : self.targetDispGrad.dot(x), '2')

self.mesh, _ = Mesh.combine_mesh(m1, m2)

order=1 # ipctk only works with 1st order meshes
self.mesh = Mesh.create_higher_order_mesh_from_simplex_mesh(self.mesh, order=order, copyNodeSets=False)

nodeSets = Mesh.create_nodesets_from_sidesets(self.mesh)
self.mesh = Mesh.mesh_with_nodesets(self.mesh, nodeSets)
quadRule = QuadratureRule.create_quadrature_rule_on_triangle(degree=2)
self.fs = FunctionSpace.construct_function_space(self.mesh, quadRule)
self.dofManager = DofManager(self.fs, dim=self.mesh.coords.shape[1], EssentialBCs=[])

write_vtk_mesh(self.mesh, 'mesh')

self.contact = IPCTKContact('mesh.vtk', self.dofManager)
self.disp = np.zeros(self.mesh.coords.shape)

def test_collisions_min_distance(self):

coords = self.contact.collision_mesh.rest_positions.copy() # just use rest positions
collision_mesh = self.contact.collision_mesh
# collisions = self.contact.collisions(coords)
self.contact = self.contact.update_collisions(coords)
collisions = self.contact.collisions

# test distance computation in initial positions (disp=0)
coords[:,0] += self.disp[:,0]
coords[:,1] += self.disp[:,1]
self.assertNear(np.sqrt(collisions.compute_minimum_distance(collision_mesh, coords)), 1e-3, 8)

# test distance computation in offset positions
index = (self.mesh.nodeSets['right1'], 0)
self.disp = self.disp.at[index].set(5e-4)
coords[:,0] += self.disp[:,0]
coords[:,1] += self.disp[:,1]
self.assertNear(np.sqrt(collisions.compute_minimum_distance(collision_mesh, coords)), 5e-4, 8)

def test_energy(self):
U = 0. * self.disp
# Uu = self.dofManager.get_unknown_values(np.zeros(self.mesh.coords.shape))
# p = Objective.Params()
# barrier_energy = self.contact.energy(Uu, p)
# barrier_energy = self.contact.energy(U)
# self.assertTrue(barrier_energy > 0.0)
# # barrier_grad = self.contact.gradient(Uu, p)
# barrier_grad = self.contact.gradient(U)

# U = self.contact.dof_manager.create_field(Uu, p[0])
coords = self.contact.collision_mesh.rest_positions.copy()[:, 0:2]
curr_coords = U + coords
# collisions = self.contact.collisions(curr_coords)
self.contact = self.contact.update_collisions(curr_coords)
collisions = self.contact.collisions

barrier_energy = self.contact.energy(U)
self.assertTrue(barrier_energy > 0.0)
barrier_grad = self.contact.gradient(U)
# barrier_grad_jit = jit(self.contact.gradient)(U)
true_grad = self.contact.potential.gradient(collisions, self.contact.collision_mesh, curr_coords)
self.assertArrayNear(barrier_grad.flatten(), true_grad, 12)
# self.assertArrayNear(barrier_grad_jit.flatten(), true_grad, 12)

# barrier_grad_jit = eqx.filter_jit(self.contact.gradient)(U)
barrier_grad_jit = eqx.filter_jit(IPCTK._gradient)(self.contact, U)
# print(barrier_grad_jit)

# v = jnp.ones(Uu.shape[0])
# barrier_hvp = self.contact.hvp(Uu, p, v)
# true_hess = self.contact.potential.hessian(collisions, self.contact.collision_mesh, curr_coords)
# true_hvp = true_hess @ v
# self.assertArrayNear(barrier_hvp, true_hvp, 12)

if __name__ == '__main__':
unittest.main()
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@
'netcdf4',
'scipy',],
#tests_require=[], # could put chex and pytest here
extras_require={'sparse': ['scikit-sparse'],
extras_require={'ipc': ['ipctk', 'meshio'],
'sparse': ['scikit-sparse'],
'test': ['pytest', 'pytest-cov', 'pytest-xdist'],
'docs': ['sphinx', 'sphinx-copybutton', 'sphinx-rtd-theme', 'sphinxcontrib-bibtex', 'sphinxcontrib-napoleon']},
python_requires='>=3.7',
Expand Down
Loading