From 63d36c33d2c533e97f0d54a344b94b3ce1cd666e Mon Sep 17 00:00:00 2001 From: "Craig M. Hamel" Date: Tue, 19 Nov 2024 09:38:24 -0700 Subject: [PATCH 1/3] adding an initial stab at an equinox module that wraps ipctk. --- examples/ipc/example.py | 53 +++++++++++++++++ optimism/contact/IPCTK.py | 85 +++++++++++++++++++++++++++ optimism/contact/test/test_IPCTK.py | 91 +++++++++++++++++++++++++++++ setup.py | 3 +- 4 files changed, 231 insertions(+), 1 deletion(-) create mode 100644 examples/ipc/example.py create mode 100644 optimism/contact/IPCTK.py create mode 100755 optimism/contact/test/test_IPCTK.py diff --git a/examples/ipc/example.py b/examples/ipc/example.py new file mode 100644 index 00000000..1c36d901 --- /dev/null +++ b/examples/ipc/example.py @@ -0,0 +1,53 @@ +from jax import custom_jvp +import jax +import jax.numpy as jnp + + +@custom_jvp +def f(x): + return jnp.sin(x) + + +@f.defjvp +def _f_jvp(primals, tangents): + x, = primals + x_dot, = tangents + primal_out = f(x) + tangent_out = jnp.cos(x) * x_dot + return primal_out[0], tangent_out[0] + + +def f_jvp(x, v): + return _f_jvp(x, v) + + +# forward over reverse +@custom_jvp +def f_grad(x): + return jax.grad(f)(x) + + + +@f_grad.defjvp +def _f_hvp(primals, tangents): + x, = primals + x_dot, = tangents + primal_out = f_grad(x) + tangent_out = -jnp.sin(x) * x_dot + return primal_out[0], tangent_out[0] + +def f_hvp(x, v): + return jax.jvp(f_grad, x, v) + +def f_hess(x): + return jax.hessian(f)((x,)) + + +x = jnp.array([jnp.sqrt(2.)]) +v = jnp.ones(1) +print(f(x)) +print(f_grad(x)) +# print(jax.grad(f)(x)) +# print(hvp(x, v)) +print(f_hvp((x,), (v,))) +print(f_hess(x)) \ No newline at end of file diff --git a/optimism/contact/IPCTK.py b/optimism/contact/IPCTK.py new file mode 100644 index 00000000..10c247c8 --- /dev/null +++ b/optimism/contact/IPCTK.py @@ -0,0 +1,85 @@ +from .. import QuadratureRule +from functools import partial +import equinox as eqx +import ipctk +import jax +import jax.numpy as jnp +import meshio + + +class IPCTKContact(eqx.Module): + mesh: any + collision_mesh: any + dof_manager: any + max_neighbors: int + q_rule: QuadratureRule.QuadratureRule + dhat: float + potential: ipctk.BarrierPotential + + def __init__(self, mesh_file, dof_manager, max_neighbors=3, q_degree=3, dhat=2e-3): + self.mesh = meshio.read(mesh_file) + self.max_neighbors = max_neighbors + self.q_rule = QuadratureRule.create_quadrature_rule_1D(q_degree) + + rest_positions = self.mesh.points + faces = self.mesh.cells_dict['triangle'] + edges = ipctk.edges(faces) + self.dhat = dhat + self.collision_mesh = ipctk.CollisionMesh(rest_positions, edges) + self.potential = ipctk.BarrierPotential(self.dhat) + self.dof_manager = dof_manager + + def collisions(self, coords): + collisions = ipctk.Collisions() + collisions.build(self.collision_mesh, coords, self.dhat) # performs culling to find only potential collisions with distances less than dhat + return collisions + + def energy(self, Uu, p): + return _energy(self, Uu, p) + + def gradient(self, Uu, p): + return _gradient(self, Uu, p) + + def hvp(self, Uu, p, v): + return jax.jvp(self.gradient, (Uu, p), (v, p))[1] + +# annoying below since jvps don't really play nice with classes that well + +@partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _energy(contact, Uu, p): + U = contact.dof_manager.create_field(Uu, p[0]) + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = contact.collisions(curr_coords) + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + return barrier_energy + +@_energy.defjvp +def _jvp(contact, primals, tangents): + Uu, p = primals + dUu, dp = tangents + U = contact.dof_manager.create_field(Uu, p[0]) + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = contact.collisions(curr_coords) + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords) + return barrier_energy, jnp.dot(barrier_grad, dUu) + + +@partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _gradient(contact, Uu, p): + return jax.grad(contact.energy, argnums=0)(Uu, p) + + +@_gradient.defjvp +def _hvp(contact, primals, tangents): + Uu, p = primals + dUu, dp = tangents + U = contact.dof_manager.create_field(Uu, p[0]) + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = contact.collisions(curr_coords) + barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords) + barrier_hess = contact.potential.hessian(collisions, contact.collision_mesh, curr_coords) + return barrier_grad, barrier_hess @ dUu diff --git a/optimism/contact/test/test_IPCTK.py b/optimism/contact/test/test_IPCTK.py new file mode 100755 index 00000000..8437788a --- /dev/null +++ b/optimism/contact/test/test_IPCTK.py @@ -0,0 +1,91 @@ +import ipctk +import jax +import jax.numpy as jnp +import meshio +import unittest +from optimism import VTKWriter +from optimism.JaxConfig import * +from optimism import FunctionSpace +from optimism import Mesh +from optimism import Objective +from optimism import QuadratureRule +from optimism.FunctionSpace import DofManager +from optimism.test.MeshFixture import MeshFixture +from optimism.contact import Contact, IPCTK +from optimism.contact.IPCTK import IPCTKContact + + +def write_vtk_mesh(mesh, meshName): + writer = VTKWriter.VTKWriter(mesh, baseFileName=meshName) + writer.write() + + +class TwoBodyICPTKContactFixture(MeshFixture): + + def setUp(self): + self.targetDispGrad = np.array([[0.1, -0.2],[0.4, -0.1]]) + + m1 = self.create_mesh_and_disp(3, 5, [0.0, 1.0], [0.0, 1.0], + lambda x : self.targetDispGrad.dot(x), '1') + + m2 = self.create_mesh_and_disp(2, 4, [1.001, 2.001], [0.0, 1.0], + lambda x : self.targetDispGrad.dot(x), '2') + + self.mesh, _ = Mesh.combine_mesh(m1, m2) + + order=1 # ipctk only works with 1st order meshes + self.mesh = Mesh.create_higher_order_mesh_from_simplex_mesh(self.mesh, order=order, copyNodeSets=False) + + nodeSets = Mesh.create_nodesets_from_sidesets(self.mesh) + self.mesh = Mesh.mesh_with_nodesets(self.mesh, nodeSets) + quadRule = QuadratureRule.create_quadrature_rule_on_triangle(degree=2) + self.fs = FunctionSpace.construct_function_space(self.mesh, quadRule) + self.dofManager = DofManager(self.fs, dim=self.mesh.coords.shape[1], EssentialBCs=[]) + + write_vtk_mesh(self.mesh, 'mesh') + + self.contact = IPCTKContact('mesh.vtk', self.dofManager) + self.disp = np.zeros(self.mesh.coords.shape) + + def test_collisions_min_distance(self): + + coords = self.contact.collision_mesh.rest_positions.copy() # just use rest positions + collision_mesh = self.contact.collision_mesh + collisions = self.contact.collisions(coords) + + # test distance computation in initial positions (disp=0) + coords[:,0] += self.disp[:,0] + coords[:,1] += self.disp[:,1] + self.assertNear(np.sqrt(collisions.compute_minimum_distance(collision_mesh, coords)), 1e-3, 8) + + # test distance computation in offset positions + index = (self.mesh.nodeSets['right1'], 0) + self.disp = self.disp.at[index].set(5e-4) + coords[:,0] += self.disp[:,0] + coords[:,1] += self.disp[:,1] + self.assertNear(np.sqrt(collisions.compute_minimum_distance(collision_mesh, coords)), 5e-4, 8) + + def test_energy(self): + # U = 0. * self.disp + Uu = self.dofManager.get_unknown_values(np.zeros(self.mesh.coords.shape)) + p = Objective.Params() + barrier_energy = self.contact.energy(Uu, p) + self.assertTrue(barrier_energy > 0.0) + barrier_grad = self.contact.gradient(Uu, p) + + U = self.contact.dof_manager.create_field(Uu, p[0]) + coords = self.contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = self.contact.collisions(curr_coords) + + true_grad = self.contact.potential.gradient(collisions, self.contact.collision_mesh, curr_coords) + self.assertArrayNear(barrier_grad, true_grad, 12) + + v = jnp.ones(Uu.shape[0]) + barrier_hvp = self.contact.hvp(Uu, p, v) + true_hess = self.contact.potential.hessian(collisions, self.contact.collision_mesh, curr_coords) + true_hvp = true_hess @ v + self.assertArrayNear(barrier_hvp, true_hvp, 12) + +if __name__ == '__main__': + unittest.main() diff --git a/setup.py b/setup.py index c4edecda..3fb9e05a 100644 --- a/setup.py +++ b/setup.py @@ -13,7 +13,8 @@ 'netcdf4', 'scipy',], #tests_require=[], # could put chex and pytest here - extras_require={'sparse': ['scikit-sparse'], + extras_require={'ipc': ['ipctk', 'meshio'], + 'sparse': ['scikit-sparse'], 'test': ['pytest', 'pytest-cov', 'pytest-xdist'], 'docs': ['sphinx', 'sphinx-copybutton', 'sphinx-rtd-theme', 'sphinxcontrib-bibtex', 'sphinxcontrib-napoleon']}, python_requires='>=3.7', From 3db33afa35826ba65b2b5fda2bb7917695431852 Mon Sep 17 00:00:00 2001 From: cmhamel <31457225+cmhamel@users.noreply.github.com> Date: Tue, 19 Nov 2024 11:58:20 -0500 Subject: [PATCH 2/3] Update ci-build.yml --- .github/workflows/ci-build.yml | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/.github/workflows/ci-build.yml b/.github/workflows/ci-build.yml index 94385938..5ccae342 100644 --- a/.github/workflows/ci-build.yml +++ b/.github/workflows/ci-build.yml @@ -28,17 +28,17 @@ jobs: sudo apt-get install -y libsuitesparse-dev python -m pip install --upgrade pip # pip install . - pip install -e ".[sparse]" + pip install -e ".[ipc,sparse]" # pip install flake8 pytest # if [ -f requirements.txt ]; then pip install -r requirements-test.txt; fi - name: pytest run: | - pip install -e ".[sparse,test]" + pip install -e ".[ipc,sparse,test]" python -m pytest -n auto optimism --cov=optimism -Wignore # we can also add the flag -n auto for parallel testing - name: docs run: | - pip install -e ".[docs,sparse,test]" + pip install -e ".[docs,ipc,sparse,test]" cd docs sphinx-apidoc -o source/ ../optimism -P make html From 9b83495315cd55abd11403d6936863c862e29228 Mon Sep 17 00:00:00 2001 From: "Craig M. Hamel" Date: Thu, 19 Dec 2024 13:05:12 -0700 Subject: [PATCH 3/3] some more attempts. --- optimism/contact/IPCTK.py | 78 ++++++++++++++++++++++------- optimism/contact/test/test_IPCTK.py | 50 +++++++++++------- 2 files changed, 94 insertions(+), 34 deletions(-) diff --git a/optimism/contact/IPCTK.py b/optimism/contact/IPCTK.py index 10c247c8..05633ce5 100644 --- a/optimism/contact/IPCTK.py +++ b/optimism/contact/IPCTK.py @@ -5,11 +5,13 @@ import jax import jax.numpy as jnp import meshio +import numpy as np class IPCTKContact(eqx.Module): mesh: any collision_mesh: any + collisions: any dof_manager: any max_neighbors: int q_rule: QuadratureRule.QuadratureRule @@ -19,43 +21,85 @@ class IPCTKContact(eqx.Module): def __init__(self, mesh_file, dof_manager, max_neighbors=3, q_degree=3, dhat=2e-3): self.mesh = meshio.read(mesh_file) self.max_neighbors = max_neighbors - self.q_rule = QuadratureRule.create_quadrature_rule_1D(q_degree) - + q_rule = QuadratureRule.create_quadrature_rule_1D(q_degree) + q_rule = eqx.tree_at(lambda x: x.xigauss, q_rule, jnp.array(q_rule.xigauss)) + self.q_rule = eqx.tree_at(lambda x: x.wgauss, q_rule, jnp.array(q_rule.wgauss)) rest_positions = self.mesh.points faces = self.mesh.cells_dict['triangle'] edges = ipctk.edges(faces) self.dhat = dhat self.collision_mesh = ipctk.CollisionMesh(rest_positions, edges) + # self.collisions = self.update_collisions() + self.collisions = None self.potential = ipctk.BarrierPotential(self.dhat) self.dof_manager = dof_manager - def collisions(self, coords): + + + def energy(self, U): + return _energy(self, U) + + def gradient(self, U): + return _gradient(self, U) + + def update_collisions(self, coords): + coords = np.array(coords) collisions = ipctk.Collisions() collisions.build(self.collision_mesh, coords, self.dhat) # performs culling to find only potential collisions with distances less than dhat - return collisions - - def energy(self, Uu, p): - return _energy(self, Uu, p) + new_self = eqx.tree_at(lambda x: x.collisions, self, collisions) + return new_self + # return collisions + # def energy(self, Uu, p): + # return _energy(self, Uu, p) - def gradient(self, Uu, p): - return _gradient(self, Uu, p) + # def gradient(self, Uu, p): + # return _gradient(self, Uu, p) - def hvp(self, Uu, p, v): - return jax.jvp(self.gradient, (Uu, p), (v, p))[1] + # def hvp(self, Uu, p, v): + # return jax.jvp(self.gradient, (Uu, p), (v, p))[1] # annoying below since jvps don't really play nice with classes that well @partial(jax.custom_jvp, nondiff_argnums=(0,)) -def _energy(contact, Uu, p): - U = contact.dof_manager.create_field(Uu, p[0]) +def _energy(contact, U): coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] curr_coords = U + coords - collisions = contact.collisions(curr_coords) + # collisions = contact.collisions(curr_coords) + collisions = contact.collisions barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) return barrier_energy + @_energy.defjvp def _jvp(contact, primals, tangents): + U, = primals + dU, = tangents + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + # collisions = contact.collisions(curr_coords) + collisions = contact.collisions + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + barrier_grad = contact.potential.gradient(collisions, contact.collision_mesh, curr_coords) + return barrier_energy, jnp.dot(barrier_grad, dU.flatten()) + + +# @partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _gradient(contact, U): + return jax.grad(contact.energy)(U) + + +@partial(jax.custom_jvp, nondiff_argnums=(0,)) +def _energy_old(contact, Uu, p): + U = contact.dof_manager.create_field(Uu, p[0]) + coords = contact.collision_mesh.rest_positions.copy()[:, 0:2] + curr_coords = U + coords + collisions = contact.collisions(curr_coords) + barrier_energy = contact.potential(collisions, contact.collision_mesh, curr_coords) + return barrier_energy + + +@_energy_old.defjvp +def _jvp_old(contact, primals, tangents): Uu, p = primals dUu, dp = tangents U = contact.dof_manager.create_field(Uu, p[0]) @@ -68,12 +112,12 @@ def _jvp(contact, primals, tangents): @partial(jax.custom_jvp, nondiff_argnums=(0,)) -def _gradient(contact, Uu, p): +def _gradient_old(contact, Uu, p): return jax.grad(contact.energy, argnums=0)(Uu, p) -@_gradient.defjvp -def _hvp(contact, primals, tangents): +@_gradient_old.defjvp +def _hvp_old(contact, primals, tangents): Uu, p = primals dUu, dp = tangents U = contact.dof_manager.create_field(Uu, p[0]) diff --git a/optimism/contact/test/test_IPCTK.py b/optimism/contact/test/test_IPCTK.py index 8437788a..71a8d716 100755 --- a/optimism/contact/test/test_IPCTK.py +++ b/optimism/contact/test/test_IPCTK.py @@ -1,3 +1,4 @@ +import equinox as eqx import ipctk import jax import jax.numpy as jnp @@ -51,7 +52,9 @@ def test_collisions_min_distance(self): coords = self.contact.collision_mesh.rest_positions.copy() # just use rest positions collision_mesh = self.contact.collision_mesh - collisions = self.contact.collisions(coords) + # collisions = self.contact.collisions(coords) + self.contact = self.contact.update_collisions(coords) + collisions = self.contact.collisions # test distance computation in initial positions (disp=0) coords[:,0] += self.disp[:,0] @@ -66,26 +69,39 @@ def test_collisions_min_distance(self): self.assertNear(np.sqrt(collisions.compute_minimum_distance(collision_mesh, coords)), 5e-4, 8) def test_energy(self): - # U = 0. * self.disp - Uu = self.dofManager.get_unknown_values(np.zeros(self.mesh.coords.shape)) - p = Objective.Params() - barrier_energy = self.contact.energy(Uu, p) - self.assertTrue(barrier_energy > 0.0) - barrier_grad = self.contact.gradient(Uu, p) - - U = self.contact.dof_manager.create_field(Uu, p[0]) + U = 0. * self.disp + # Uu = self.dofManager.get_unknown_values(np.zeros(self.mesh.coords.shape)) + # p = Objective.Params() + # barrier_energy = self.contact.energy(Uu, p) + # barrier_energy = self.contact.energy(U) + # self.assertTrue(barrier_energy > 0.0) + # # barrier_grad = self.contact.gradient(Uu, p) + # barrier_grad = self.contact.gradient(U) + + # U = self.contact.dof_manager.create_field(Uu, p[0]) coords = self.contact.collision_mesh.rest_positions.copy()[:, 0:2] curr_coords = U + coords - collisions = self.contact.collisions(curr_coords) + # collisions = self.contact.collisions(curr_coords) + self.contact = self.contact.update_collisions(curr_coords) + collisions = self.contact.collisions + barrier_energy = self.contact.energy(U) + self.assertTrue(barrier_energy > 0.0) + barrier_grad = self.contact.gradient(U) + # barrier_grad_jit = jit(self.contact.gradient)(U) true_grad = self.contact.potential.gradient(collisions, self.contact.collision_mesh, curr_coords) - self.assertArrayNear(barrier_grad, true_grad, 12) - - v = jnp.ones(Uu.shape[0]) - barrier_hvp = self.contact.hvp(Uu, p, v) - true_hess = self.contact.potential.hessian(collisions, self.contact.collision_mesh, curr_coords) - true_hvp = true_hess @ v - self.assertArrayNear(barrier_hvp, true_hvp, 12) + self.assertArrayNear(barrier_grad.flatten(), true_grad, 12) + # self.assertArrayNear(barrier_grad_jit.flatten(), true_grad, 12) + + # barrier_grad_jit = eqx.filter_jit(self.contact.gradient)(U) + barrier_grad_jit = eqx.filter_jit(IPCTK._gradient)(self.contact, U) + # print(barrier_grad_jit) + + # v = jnp.ones(Uu.shape[0]) + # barrier_hvp = self.contact.hvp(Uu, p, v) + # true_hess = self.contact.potential.hessian(collisions, self.contact.collision_mesh, curr_coords) + # true_hvp = true_hess @ v + # self.assertArrayNear(barrier_hvp, true_hvp, 12) if __name__ == '__main__': unittest.main()