diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 94385938..5ccae342 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -28,17 +28,17 @@ jobs: sudo apt-get install -y libsuitesparse-dev python -m pip install --upgrade pip # pip install . - pip install -e ".[sparse]" + pip install -e ".[ipc,sparse]" # pip install flake8 pytest # if [ -f requirements.txt ]; then pip install -r requirements-test.txt; fi - name: pytest run: | - pip install -e ".[sparse,test]" + pip install -e ".[ipc,sparse,test]" python -m pytest -n auto optimism --cov=optimism -Wignore # we can also add the flag -n auto for parallel testing - name: docs run: | - pip install -e ".[docs,sparse,test]" + pip install -e ".[docs,ipc,sparse,test]" cd docs sphinx-apidoc -o source/ ../optimism -P make html diff --git a/examples/ipc/example.py b/examples/ipc/example.py new file mode 100644 index 00000000..1c36d901 --- /dev/null +++ b/examples/ipc/example.py @@ -0,0 +1,53 @@ +from jax import custom_jvp +import jax +import jax.numpy as jnp + + +@custom_jvp +def f(x): + return jnp.sin(x) + + +@f.defjvp +def _f_jvp(primals, tangents): + x, = primals + x_dot, = tangents + primal_out = f(x) + tangent_out = jnp.cos(x) * x_dot + return primal_out[0], tangent_out[0] + + +def f_jvp(x, v): + return _f_jvp(x, v) + + +# forward over reverse +@custom_jvp +def f_grad(x): + return jax.grad(f)(x) + + + +@f_grad.defjvp +def _f_hvp(primals, tangents): + x, = primals + x_dot, = tangents + primal_out = f_grad(x) + tangent_out = -jnp.sin(x) * x_dot + return primal_out[0], tangent_out[0] + +def f_hvp(x, v): + return jax.jvp(f_grad, x, v) + +def f_hess(x): + return jax.hessian(f)((x,)) + + +x = jnp.array([jnp.sqrt(2.)]) +v = jnp.ones(1) +print(f(x)) +print(f_grad(x)) +# print(jax.grad(f)(x)) +# print(hvp(x, v)) +print(f_hvp((x,), (v,))) +print(f_hess(x)) \ No newline at end of file diff --git a/optimism/contact/IPCTK.py b/optimism/contact/IPCTK.py new file mode 100644 index 00000000..05633ce5 --- /dev/null +++ b/optimism/contact/IPCTK.py @@ -0,0 +1,129 @@ +from .. import QuadratureRule +from functools import partial +import equinox as eqx +import ipctk +import jax +import jax.numpy as jnp +import meshio +import numpy as np + + +class IPCTKContact(eqx.Module): + mesh: any + collision_mesh: any + collisions: any + dof_manager: any + max_neighbors: int + q_rule: QuadratureRule.QuadratureRule + dhat: float + potential: ipctk.BarrierPotential + + def __init__(self, mesh_file, dof_manager, max_neighbors=3, q_degree=3, dhat=2e-3): + self.mesh = meshio.read(mesh_file) + self.max_neighbors = max_neighbors + q_rule = QuadratureRule.create_quadrature_rule_1D(q_degree) + q_rule = eqx.tree_at(lambda x: x.xigauss, q_rule, jnp.array(q_rule.xigauss)) + self.q_rule = eqx.tree_at(lambda x: x.wgauss, q_rule, jnp.array(q_rule.wgauss)) + rest_positions = self.mesh.points + faces = self.mesh.cells_dict['triangle'] + edges = ipctk.edges(faces) + self.dhat = dhat + self.collision_mesh = ipctk.CollisionMesh(rest_positions, edges) + # self.collisions = self.update_collisions() + self.collisions = None + self.potential = ipctk.BarrierPotential(self.dhat) + self.dof_manager = dof_manager + + + + def energy(self, U): + return _energy(self, U) + + def gradient(self, U): + return _gradient(self, U) + + def update_collisions(self, coords): + coords = np.array(coords) + collisions = ipctk.Collisions() + collisions.build(self.collision_mesh, coords, self.dhat) # performs culling to find only potential collisions with distances less than dhat + new_self = eqx.tree_at(lambda x: x.collisions, self, collisions) + return new_self + # return collisions + # def energy(self, Uu, p): + # return _energy(self, Uu, p) + + # def gradient(self, Uu, p): + # return _gradient(self, Uu, p) + + # def hvp(self, Uu, p, v): + # return jax.jvp(self.gradient, (Uu, p), (v, p))[1] + +# annoying below since jvps don't really play nice with classes that well + +@partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _energy(contact, U): + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + # collisions = contact.collisions(curr_coords) + collisions = contact.collisions + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + return barrier_energy + + +@_energy.defjvp +def _jvp(contact, primals, tangents): + U, = primals + dU, = tangents + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + # collisions = contact.collisions(curr_coords) + collisions = contact.collisions + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords) + return barrier_energy, jnp.dot(barrier_grad, dU.flatten()) + + +# @partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _gradient(contact, U): + return jax.grad(contact.energy)(U) + + +@partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _energy_old(contact, Uu, p): + U = contact.dof_manager.create_field(Uu, p[0]) + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = contact.collisions(curr_coords) + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + return barrier_energy + + +@_energy_old.defjvp +def _jvp_old(contact, primals, tangents): + Uu, p = primals + dUu, dp = tangents + U = contact.dof_manager.create_field(Uu, p[0]) + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = contact.collisions(curr_coords) + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords) + return barrier_energy, jnp.dot(barrier_grad, dUu) + + +@partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _gradient_old(contact, Uu, p): + return jax.grad(contact.energy, argnums=0)(Uu, p) + + +@_gradient_old.defjvp +def _hvp_old(contact, primals, tangents): + Uu, p = primals + dUu, dp = tangents + U = contact.dof_manager.create_field(Uu, p[0]) + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = contact.collisions(curr_coords) + barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords) + barrier_hess = contact.potential.hessian(collisions, contact.collision_mesh, curr_coords) + return barrier_grad, barrier_hess @ dUu diff --git a/optimism/contact/test/test_IPCTK.py b/optimism/contact/test/test_IPCTK.py new file mode 100755 index 00000000..71a8d716 --- /dev/null +++ b/optimism/contact/test/test_IPCTK.py @@ -0,0 +1,107 @@ +import equinox as eqx +import ipctk +import jax +import jax.numpy as jnp +import meshio +import unittest +from optimism import VTKWriter +from optimism.JaxConfig import * +from optimism import FunctionSpace +from optimism import Mesh +from optimism import Objective +from optimism import QuadratureRule +from optimism.FunctionSpace import DofManager +from optimism.test.MeshFixture import MeshFixture +from optimism.contact import Contact, IPCTK +from optimism.contact.IPCTK import IPCTKContact + + +def write_vtk_mesh(mesh, meshName): + writer = VTKWriter.VTKWriter(mesh, baseFileName=meshName) + writer.write() + + +class TwoBodyICPTKContactFixture(MeshFixture): + + def setUp(self): + self.targetDispGrad = np.array([[0.1, -0.2],[0.4, -0.1]]) + + m1 = self.create_mesh_and_disp(3, 5, [0.0, 1.0], [0.0, 1.0], + lambda x : self.targetDispGrad.dot(x), '1') + + m2 = self.create_mesh_and_disp(2, 4, [1.001, 2.001], [0.0, 1.0], + lambda x : self.targetDispGrad.dot(x), '2') + + self.mesh, _ = Mesh.combine_mesh(m1, m2) + + order=1 # ipctk only works with 1st order meshes + self.mesh = Mesh.create_higher_order_mesh_from_simplex_mesh(self.mesh, order=order, copyNodeSets=False) + + nodeSets = Mesh.create_nodesets_from_sidesets(self.mesh) + self.mesh = Mesh.mesh_with_nodesets(self.mesh, nodeSets) + quadRule = QuadratureRule.create_quadrature_rule_on_triangle(degree=2) + self.fs = FunctionSpace.construct_function_space(self.mesh, quadRule) + self.dofManager = DofManager(self.fs, dim=self.mesh.coords.shape[1], EssentialBCs=[]) + + write_vtk_mesh(self.mesh, 'mesh') + + self.contact = IPCTKContact('mesh.vtk', self.dofManager) + self.disp = np.zeros(self.mesh.coords.shape) + + def test_collisions_min_distance(self): + + coords = self.contact.collision_mesh.rest_positions.copy() # just use rest positions + collision_mesh = self.contact.collision_mesh + # collisions = self.contact.collisions(coords) + self.contact = self.contact.update_collisions(coords) + collisions = self.contact.collisions + + # test distance computation in initial positions (disp=0) + coords[:,0] += self.disp[:,0] + coords[:,1] += self.disp[:,1] + self.assertNear(np.sqrt(collisions.compute_minimum_distance(collision_mesh, coords)), 1e-3, 8) + + # test distance computation in offset positions + index = (self.mesh.nodeSets['right1'], 0) + self.disp = self.disp.at[index].set(5e-4) + coords[:,0] += self.disp[:,0] + coords[:,1] += self.disp[:,1] + self.assertNear(np.sqrt(collisions.compute_minimum_distance(collision_mesh, coords)), 5e-4, 8) + + def test_energy(self): + U = 0. * self.disp + # Uu = self.dofManager.get_unknown_values(np.zeros(self.mesh.coords.shape)) + # p = Objective.Params() + # barrier_energy = self.contact.energy(Uu, p) + # barrier_energy = self.contact.energy(U) + # self.assertTrue(barrier_energy > 0.0) + # # barrier_grad = self.contact.gradient(Uu, p) + # barrier_grad = self.contact.gradient(U) + + # U = self.contact.dof_manager.create_field(Uu, p[0]) + coords = self.contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + # collisions = self.contact.collisions(curr_coords) + self.contact = self.contact.update_collisions(curr_coords) + collisions = self.contact.collisions + + barrier_energy = self.contact.energy(U) + self.assertTrue(barrier_energy > 0.0) + barrier_grad = self.contact.gradient(U) + # barrier_grad_jit = jit(self.contact.gradient)(U) + true_grad = self.contact.potential.gradient(collisions, self.contact.collision_mesh, curr_coords) + self.assertArrayNear(barrier_grad.flatten(), true_grad, 12) + # self.assertArrayNear(barrier_grad_jit.flatten(), true_grad, 12) + + # barrier_grad_jit = eqx.filter_jit(self.contact.gradient)(U) + barrier_grad_jit = eqx.filter_jit(IPCTK._gradient)(self.contact, U) + # print(barrier_grad_jit) + + # v = jnp.ones(Uu.shape[0]) + # barrier_hvp = self.contact.hvp(Uu, p, v) + # true_hess = self.contact.potential.hessian(collisions, self.contact.collision_mesh, curr_coords) + # true_hvp = true_hess @ v + # self.assertArrayNear(barrier_hvp, true_hvp, 12) + +if __name__ == '__main__': + unittest.main() diff --git a/setup.py b/setup.py index c4edecda..3fb9e05a 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,8 @@ 'netcdf4', 'scipy',], #tests_require=[], # could put chex and pytest here - extras_require={'sparse': ['scikit-sparse'], + extras_require={'ipc': ['ipctk', 'meshio'], + 'sparse': ['scikit-sparse'], 'test': ['pytest', 'pytest-cov', 'pytest-xdist'], 'docs': ['sphinx', 'sphinx-copybutton', 'sphinx-rtd-theme', 'sphinxcontrib-bibtex', 'sphinxcontrib-napoleon']}, python_requires='>=3.7',