Skip to content

Commit

Permalink
see issue #34. Adapted the prediction interval formulas
Browse files Browse the repository at this point in the history
  • Loading branch information
jkitchin committed Jan 3, 2024
1 parent 4f50345 commit 82111d0
Show file tree
Hide file tree
Showing 11 changed files with 59 additions and 106 deletions.
4 changes: 2 additions & 2 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,13 @@ repos:
rev: 22.3.0
hooks:
- id: black
entry: black --line-length 80 pycse
entry: black --line-length 100 pycse

- repo: https://github.com/PyCQA/flake8
rev: 4.0.1
hooks:
- id: flake8
entry: flake8 --max-line-length 80 pycse
entry: flake8 --max-line-length 100 pycse

- repo: local
hooks:
Expand Down
5 changes: 3 additions & 2 deletions docker/Dockerfile
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# https://jupyter-docker-stacks.readthedocs.io/en/latest/
# Note this takes about 30 minutes to build

FROM jupyter/scipy-notebook:python-3.9
FROM jupyter/scipy-notebook:python-3.10
MAINTAINER John Kitchin <[email protected]>

# Set the default shell to bash instead of sh so the source commands work
Expand Down Expand Up @@ -34,7 +34,8 @@ RUN python -m pip install --upgrade pip \
jupyterlab-git jupyter-videochat \
jupyterlab-spellchecker \
jupyterlab-code-formatter \
pydotplus ase \
pydotplus ase \
parsl \
&& jupyter labextension install plotlywidget \
&& pip install jupyterlab-link-share \
&& jupyter server extension disable nbclassic
Expand Down
37 changes: 28 additions & 9 deletions pycse/PYCSE.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,6 +146,8 @@ def predict(X, y, pars, XX, alpha=0.05, ub=1e-5, ef=1.05):
ub : upper bound for smallest allowed Hessian eigenvalue
ef : eigenvalue factor for scaling Hessian
See https://en.wikipedia.org/wiki/Prediction_interval#Unknown_mean,_unknown_variance
Returns
y, yint, pred_se
y : the predicted values
Expand All @@ -158,20 +160,29 @@ def predict(X, y, pars, XX, alpha=0.05, ub=1e-5, ef=1.05):

errs = y - X @ pars
sse = errs @ errs
rmse = sse / n
mse = sse / n

gprime = XX
hat = 2 * X.T @ X # hessian
eps = max(ub, ef * np.linalg.eigvals(hat).min())

# Scaled Fisher information
I_fisher = rmse * np.linalg.pinv(hat + np.eye(npars) * eps)
I_fisher = np.linalg.pinv(hat + np.eye(npars) * eps)

pred_se = np.diag(gprime @ I_fisher @ gprime.T) ** 0.5
pred_se = np.sqrt(mse * np.diag(gprime @ I_fisher @ gprime.T))
tval = t.ppf(1.0 - alpha / 2.0, dof)

yy = XX @ pars
return (yy, np.array([yy + tval * pred_se, yy - tval * pred_se]).T, pred_se)
return (
yy,
np.array(
[
yy + tval * pred_se * (1 + 1 / n),
yy - tval * pred_se * (1 + 1 / n),
]
).T,
pred_se,
)


# * Nonlinear regression
Expand Down Expand Up @@ -249,6 +260,8 @@ def nlpredict(X, y, model, loss, popt, xnew, alpha=0.05, ub=1e-5, ef=1.05):
This function uses numdifftools for the Hessian and Jacobian.
See https://en.wikipedia.org/wiki/Prediction_interval#Unknown_mean,_unknown_variance
Returns
-------
Expand All @@ -266,18 +279,24 @@ def nlpredict(X, y, model, loss, popt, xnew, alpha=0.05, ub=1e-5, ef=1.05):
eps = max(ub, ef * np.linalg.eigvals(hessp).min())

sse = loss(*popt)
rmse = sse / len(y)
I_fisher = rmse * np.linalg.pinv(hessp + np.eye(len(popt)) * eps)
n = len(y)
mse = sse / n
I_fisher = np.linalg.pinv(hessp + np.eye(len(popt)) * eps)

gprime = nd.Jacobian(lambda p: model(xnew, *p))(popt)

uncerts = np.sqrt(np.diag(gprime @ I_fisher @ gprime.T))
sigmas = np.sqrt(mse * np.diag(gprime @ I_fisher @ gprime.T))
tval = t.ppf(1 - alpha / 2, len(y) - len(popt))

return [
ypred,
np.array([ypred + tval * uncerts, ypred - tval * uncerts]).T,
uncerts,
np.array(
[
ypred + tval * sigmas * (1 + 1 / n),
ypred - tval * sigmas * (1 + 1 / n),
]
).T,
sigmas,
]


Expand Down
27 changes: 6 additions & 21 deletions pycse/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,7 @@ def pycse():
"""CLI to launch a Docker image with Jupyter lab in the CWD.
This assumes you have a working Docker installation."""
if shutil.which("docker") is None:
raise Exception(
"docker was not found."
" Please install it from https://www.docker.com/"
)
raise Exception("docker was not found." " Please install it from https://www.docker.com/")

# Check setup and get image if needed
try:
Expand All @@ -52,31 +49,19 @@ def pycse():
capture_output=True,
)
except:
subprocess.run(
["docker", "pull", "jkitchin/pycse"], capture_output=True
)
subprocess.run(["docker", "pull", "jkitchin/pycse"], capture_output=True)

# Check if the container is already running
p = subprocess.run(
["docker", "ps", "--format", '"{{.Names}}"'], capture_output=True
)
p = subprocess.run(["docker", "ps", "--format", '"{{.Names}}"'], capture_output=True)
if "pycse" in p.stdout.decode("utf-8"):
ans = input(
"There is already a pycse container running."
"Do you want to kill it? (y/n)"
)
ans = input("There is already a pycse container running." "Do you want to kill it? (y/n)")
if ans.lower() == "y":
subprocess.run("docker rm -f pycse".split())
else:
print(
"There can only be one pycse container running at a time."
"Connecting to it."
)
print("There can only be one pycse container running at a time." "Connecting to it.")

# this outputs something like 0.0.0.0:8987
p = subprocess.run(
"docker port pycse 8888".split(), capture_output=True
)
p = subprocess.run("docker port pycse 8888".split(), capture_output=True)
output = p.stdout.decode("utf-8").strip()
PORT = output.split(":")[-1]

Expand Down
39 changes: 9 additions & 30 deletions pycse/colab.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,15 +47,9 @@ def gdrive():

def aptupdate():
"""Run apt-get update."""
s = subprocess.run(
["apt-get", "update"], stdout=subprocess.PIPE, stderr=subprocess.PIPE
)
s = subprocess.run(["apt-get", "update"], stdout=subprocess.PIPE, stderr=subprocess.PIPE)
if s.returncode != 0:
raise Exception(
f"apt-get update failed.\n"
f"{s.stdout.decode()}\n"
f"{s.stderr.decode()}"
)
raise Exception(f"apt-get update failed.\n" f"{s.stdout.decode()}\n" f"{s.stderr.decode()}")


def aptinstall(apt_pkg):
Expand All @@ -68,9 +62,7 @@ def aptinstall(apt_pkg):
)
if s.returncode != 0:
raise Exception(
f"{apt_pkg} installation failed.\n"
f"{s.stdout.decode()}\n"
f"{s.stderr.decode()}"
f"{apt_pkg} installation failed.\n" f"{s.stdout.decode()}\n" f"{s.stderr.decode()}"
)


Expand Down Expand Up @@ -117,9 +109,7 @@ def notebook_string(fid):
return ipynb


def pdf_from_html(
pdf=None, verbose=False, plotly=False, javascript_delay=10000
):
def pdf_from_html(pdf=None, verbose=False, plotly=False, javascript_delay=10000):
"""Export the current notebook as a PDF.
pdf is the name of the PDF to export.
Expand Down Expand Up @@ -344,18 +334,14 @@ def fid_from_url(url):
return p

# https://docs.google.com/spreadsheets/d/1qSaBe73Pd8L3jJyOL68klp6yRArW7Nce/edit#gid=1923176268
elif (u.netloc == "docs.google.com") and u.path.startswith(
"/spreadsheets/d/"
):
elif (u.netloc == "docs.google.com") and u.path.startswith("/spreadsheets/d/"):
p = u.path
p = p.replace("/spreadsheets/d/", "")
p = p.replace("/edit", "")
return p

# https://docs.google.com/presentation/d/1poP1gvWlfeZCR_5FsIzlRPMAYlBUR827wKPjbWGzW9M/edit#slide=id.p
elif (u.netloc == "docs.google.com") and u.path.startswith(
"/presentation/d/"
):
elif (u.netloc == "docs.google.com") and u.path.startswith("/presentation/d/"):
p = u.path
p = p.replace("/presentation/d/", "")
p = p.replace("/edit", "")
Expand Down Expand Up @@ -600,11 +586,7 @@ def gdownload(*FILES, **kwargs):
stderr=subprocess.PIPE,
)
if s.returncode != 0:
print(
f"zip did not fully succeed:\n"
f"{s.stdout.decode()}\n"
f"{s.stderr.decode()}\n"
)
print(f"zip did not fully succeed:\n" f"{s.stdout.decode()}\n" f"{s.stderr.decode()}\n")
files.download(zipfile)


Expand Down Expand Up @@ -666,9 +648,7 @@ def gsuite(fid_or_url, width=1200, height=1000):
# Assume we have an fid
x = (
drive_service.files()
.get(
fileId=fid_or_url, supportsAllDrives=True, fields="webViewLink"
)
.get(fileId=fid_or_url, supportsAllDrives=True, fields="webViewLink")
.execute()
)
url = x.get("webViewLink", "No web link found.")
Expand All @@ -679,8 +659,7 @@ def gsuite(fid_or_url, width=1200, height=1000):
xframeoptions = g.headers.get("X-Frame-Options", "").lower()
if xframeoptions in ["deny", "sameorigin"]:
print(
f"X-Frame-Option = {xframeoptions}\n"
f"Embedding in IFrame is not allowed for {url}."
f"X-Frame-Option = {xframeoptions}\n" f"Embedding in IFrame is not allowed for {url}."
)
return None
else:
Expand Down
9 changes: 2 additions & 7 deletions pycse/hashcache.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,9 +115,7 @@ def get_hash(func, args, kwargs):
[
func.__code__.co_name, # This is the function name
func.__code__.co_code, # this is the function bytecode
get_standardized_args(
func, args, kwargs
), # The args used, including defaults
get_standardized_args(func, args, kwargs), # The args used, including defaults
],
hash_name="sha1",
)
Expand Down Expand Up @@ -193,10 +191,7 @@ def wrapper(func, *args, **kwargs):
# is a problem here. We just warn the user. Nothing else makes
# sense, the mutability may be intentional.
if not hsh == get_hash(func, args, kwargs):
print(
"WARNING something mutated, future"
" calls will not use the cache."
)
print("WARNING something mutated, future" " calls will not use the cache.")

# Try a bunch of ways to get a username.
try:
Expand Down
11 changes: 2 additions & 9 deletions pycse/lisp.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,16 +63,9 @@ def get_dict(obj):

def lispify(L):
"""Convert a Python object L to a lisp representation."""
if (
isinstance(L, str)
or isinstance(L, float)
or isinstance(L, int)
or isinstance(L, np.int64)
):
if isinstance(L, str) or isinstance(L, float) or isinstance(L, int) or isinstance(L, np.int64):
return L.lisp
elif (
isinstance(L, list) or isinstance(L, tuple) or isinstance(L, np.ndarray)
):
elif isinstance(L, list) or isinstance(L, tuple) or isinstance(L, np.ndarray):
s = [element.lisp for element in L]
return "(" + " ".join(s) + ")"
elif isinstance(L, dict):
Expand Down
4 changes: 1 addition & 3 deletions pycse/obipython.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,7 @@ class OrgFormatter(IPython.core.formatters.BaseFormatter):
def __call__(self, obj):
"""Call function for the class."""
try:
return tabulate(
obj, headers="keys", tablefmt="orgtbl", showindex="always"
)
return tabulate(obj, headers="keys", tablefmt="orgtbl", showindex="always")
# I am not sure what exceptions get thrown, or why this is here.
except: # noqa: E722
return None
Expand Down
8 changes: 2 additions & 6 deletions pycse/orgmode.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,9 +159,7 @@ def _repr_mimebundle_(self, include, exclude, **kwargs):
class Table:
"""A Table object for org."""

def __init__(
self, data, headers=None, caption=None, name=None, attributes=()
):
def __init__(self, data, headers=None, caption=None, name=None, attributes=()):
"""Initialize a table."""
self.data = data
self.headers = headers
Expand Down Expand Up @@ -212,9 +210,7 @@ class OrgFormatter(IPython.core.formatters.BaseFormatter):
ip = get_ipython()
ip.display_formatter.formatters["text/org"] = OrgFormatter()
ytv_f = ip.display_formatter.formatters["text/org"]
ytv_f.for_type_by_name(
"IPython.lib.display", "YouTubeVideo", lambda V: f"{V.src}"
)
ytv_f.for_type_by_name("IPython.lib.display", "YouTubeVideo", lambda V: f"{V.src}")
# get_ipython is not defined for tests I think.
except NameError:
pass
12 changes: 2 additions & 10 deletions pycse/orgmode_v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,11 +148,7 @@ def table(data, name=None, caption=None, attributes=None, none=""):
if row is None:
s += ["|-"]
else:
s += [
"| "
+ " | ".join([str(x) if x is not None else none for x in row])
+ "|"
]
s += ["| " + " | ".join([str(x) if x is not None else none for x in row]) + "|"]

print("\n".join(s))

Expand Down Expand Up @@ -208,11 +204,7 @@ def comment(s):
if "\n" in str(s):
print("\n#+begin_comment\n{}\n#+end_comment\n".format(s))
else:
print(
textwrap.fill(
s, initial_indent="# ", subsequent_indent="# ", width=79
)
)
print(textwrap.fill(s, initial_indent="# ", subsequent_indent="# ", width=79))


def fixed_width(s):
Expand Down
9 changes: 2 additions & 7 deletions pycse/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,9 +77,7 @@ def read_gsheet(url, *args, **kwargs):
The url should be viewable by anyone with the link.
"""
u = urlparse(url)
if not (u.netloc == "docs.google.com") and u.path.startswith(
"/spreadsheets/d/"
):
if not (u.netloc == "docs.google.com") and u.path.startswith("/spreadsheets/d/"):
raise Exception(f"{url} does not seem to be for a sheet")

fid = u.path.split("/")[3]
Expand All @@ -90,9 +88,6 @@ def read_gsheet(url, *args, **kwargs):
# default to main sheet
gid = 0

purl = (
"https://docs.google.com/spreadsheets/d/"
f"{fid}/export?format=csv&gid={gid}"
)
purl = "https://docs.google.com/spreadsheets/d/" f"{fid}/export?format=csv&gid={gid}"

return pd.read_csv(purl, *args, **kwargs)

0 comments on commit 82111d0

Please sign in to comment.