Skip to content

Commit

Permalink
added weights to fit
Browse files Browse the repository at this point in the history
  • Loading branch information
Alex Al-Saffar committed Nov 19, 2023
1 parent 7698e97 commit c6e8ab4
Show file tree
Hide file tree
Showing 3 changed files with 9 additions and 22 deletions.
2 changes: 1 addition & 1 deletion myresources/crocodile/cluster/loader_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -495,7 +495,7 @@ def write_log(self, log: dict[JOB_STATUS, 'pd.DataFrame']):
def fetch_cloud_live(self):
remote = CloudManager.base_path
localpath = tb.P.tmp().joinpath(f"tmp_dirs/cloud_manager_live").create()
alternative_base = remote.from_cloud(cloud=self.cloud, rel2home=True, localpath=localpath.delete(sure=True), verbose=False)
alternative_base = localpath.delete(sure=True).from_cloud(cloud=self.cloud, remotepath=remote.get_remote_path(root="myhome", rel2home=True), verbose=False)
return alternative_base
@staticmethod
def prepare_servers_report(cloud_root: tb.P):
Expand Down
27 changes: 7 additions & 20 deletions myresources/crocodile/deeplearning.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,40 +345,27 @@ def fit(self, viz: bool = True, weight_name: Optional[str] = None,
validation_freq: int = 1, workers: int = 1, use_multiprocessing: bool = False,
**kwargs: Any):
assert self.data.split is not None, "Split your data before you start fitting."
<<<<<<< HEAD
x_train = [self.data.split[item] for item in self.data.specs.get_split_strings(self.data.specs.ip_names, which_split="train")]
y_train = [self.data.split[item] for item in self.data.specs.get_split_strings(self.data.specs.op_names, which_split="train")]
x_test = [self.data.split[item] for item in self.data.specs.get_split_strings(self.data.specs.ip_names, which_split="test")]
y_test = [self.data.split[item] for item in self.data.specs.get_split_strings(self.data.specs.op_names, which_split="test")]
x_train = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.ip_names, which_split="train")]
y_train = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.op_names, which_split="train")]
x_test = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.ip_names, which_split="test")]
y_test = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.op_names, which_split="test")]
if weight_name is not None:
assert weight_name in self.data.specs.other_names, f"weight_string must be one of {self.data.specs.other_names}"
if sample_weight is None:
train_weight_str = self.data.specs.get_split_strings(strings=[weight_name], which_split="train")[0]
train_weight_str = self.data.specs.get_split_names(strings=[weight_name], which_split="train")[0]
sample_weight = self.data.split[train_weight_str]
else:
print(f"⚠️ sample_weight is passed directly to `fit` method, ignoring `weight_string` argument.")
if val_sample_weight is None:
test_weight_str = self.data.specs.get_split_strings(strings=[weight_name], which_split="test")[0]
test_weight_str = self.data.specs.get_split_names(strings=[weight_name], which_split="test")[0]
val_sample_weight = self.data.split[test_weight_str]
else:
print(f"⚠️ val_sample_weight is passed directly to `fit` method, ignoring `weight_string` argument.")
=======
x_train = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.ip_names, which_split="train")]
y_train = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.op_names, which_split="train")]
x_test = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.ip_names, which_split="test")]
y_test = [self.data.split[item] for item in self.data.specs.get_split_names(self.data.specs.op_names, which_split="test")]
if weight_name is not None:
assert weight_name in self.data.specs.other_names, f"weight_string must be one of {self.data.specs.other_names}"
train_weight_str = self.data.specs.get_split_names(strings=[weight_name], which_split="train")[0]
test_weight_str = self.data.specs.get_split_names(strings=[weight_name], which_split="test")[0]
sample_weight = self.data.split[train_weight_str]
val_sample_weight = self.data.split[test_weight_str]
>>>>>>> a395a8308e7c841c7f6ae70fd8da5aaee8522d44

x_test = x_test[0] if len(x_test) == 1 else x_test
y_test = y_test[0] if len(y_test) == 1 else y_test
default_settings: dict[str, Any] = dict(x=x_train[0] if len(x_train) == 1 else x_train,
y=y_train[z0] if len(y_train) == 1 else y_train,
y=y_train[0] if len(y_train) == 1 else y_train,
validation_data=(x_test, y_test) if val_sample_weight is None else (x_test, y_test, val_sample_weight),
batch_size=self.hp.batch_size, epochs=self.hp.epochs, verbose=1, shuffle=self.hp.shuffle,
)
Expand Down
2 changes: 1 addition & 1 deletion myresources/crocodile/file_management.py
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def unzip(self, folder: OPLike = None, fname: OPLike = None, verbose: bool = Tru
else: List([x for x in __import__("zipfile").ZipFile(self.str).namelist() if "/" not in x or (len(x.split('/')) == 2 and x.endswith("/"))]).apply(lambda item: P(folder).joinpath(fname or "", item.replace("/", "")).delete(sure=True, verbose=True))
result = unzip(zipfile.str, str(folder), None if fname is None else P(fname).as_posix(), **kwargs)
assert isinstance(result, P)
return self._return(result, inlieu=False, inplace=inplace, operation="delete", orig=orig, verbose=verbose, msg=f"UNZIPPED {repr(zipfile)} ==> {repr(result)}")
return self._return(P(result), inlieu=False, inplace=inplace, operation="delete", orig=orig, verbose=verbose, msg=f"UNZIPPED {repr(zipfile)} ==> {repr(result)}")
def tar(self, folder: OPLike = None, name: OPLike = None, path: OPLike = None, inplace: bool = False, orig: bool = False, verbose: bool = True) -> 'P':
op_path = self._resolve_path(folder, name, path, self.name + ".tar").expanduser().resolve()
tar(self.expanduser().resolve().str, op_path=op_path.str); return self._return(op_path, inlieu=False, inplace=inplace, operation="delete", orig=orig, verbose=verbose, msg=f"TARRED {repr(self)} ==> {repr(op_path)}")
Expand Down

0 comments on commit c6e8ab4

Please sign in to comment.