diff --git a/miniwdl-plugins/s3upload/miniwdl_s3upload.py b/miniwdl-plugins/s3upload/miniwdl_s3upload.py index fb787f92..0d14e36a 100644 --- a/miniwdl-plugins/s3upload/miniwdl_s3upload.py +++ b/miniwdl-plugins/s3upload/miniwdl_s3upload.py @@ -89,6 +89,7 @@ def inode(link: str): _uploaded_files: Dict[Tuple[int, int], str] = {} _cached_files: Dict[Tuple[int, int], Tuple[str, Env.Bindings[Value.Base]]] = {} +_key_inputs: Dict[str, Env.Bindings[Value.Base]] = {} _uploaded_files_lock = threading.Lock() @@ -107,8 +108,18 @@ def cache(v: Union[Value.File, Value.Directory]) -> str: return _uploaded_files[inode(str(v.value))] remapped_outputs = Value.rewrite_env_paths(outputs, cache) + + input_digest = Value.digest_env( + Value.rewrite_env_paths( + _key_inputs[key], lambda v: _uploaded_files.get(inode(str(v.value)), str(v.value)) + ) + ) + key_parts = key.split('/') + key_parts[-1] = input_digest + s3_cache_key = "/".join(key_parts) + if not missing and cfg.has_option("s3_progressive_upload", "uri_prefix"): - uri = os.path.join(get_s3_put_prefix(cfg), "cache", f"{key}.json") + uri = os.path.join(get_s3_put_prefix(cfg), "cache", f"{s3_cache_key}.json") s3_object(uri).put(Body=json.dumps(values_to_json(remapped_outputs)).encode()) flag_temporary(uri) logger.info(_("call cache insert", cache_file=uri)) @@ -118,6 +129,15 @@ class CallCache(cache.CallCache): def get( self, key: str, inputs: Env.Bindings[Value.Base], output_types: Env.Bindings[Type.Base] ) -> Optional[Env.Bindings[Value.Base]]: + # HACK: in order to back the call cache in S3 we need to cache the S3 paths to the outputs. + # If we get a cache hit, those S3 paths will be passed to the next step. However, + # the cache key is computed using local inputs so this results in a cache miss. + # we need `put` to use a key based on S3 paths instead but put doesn't have access to step + # inputs. 'put' should always be run after a `get` is called so here we are storing the + # inputs based on the cache key so `put` can get the inputs. + global _key_inputs + _key_inputs[key] = inputs + if not self._cfg.has_option("s3_progressive_upload", "uri_prefix"): return super().get(key, inputs, output_types) uri = urlparse(get_s3_get_prefix(self._cfg)) diff --git a/test/test_wdl.py b/test/test_wdl.py index e46750b7..62c46f50 100644 --- a/test/test_wdl.py +++ b/test/test_wdl.py @@ -20,35 +20,42 @@ call add_world { input: - hello = hello, + input_file = hello, docker_image_id = docker_image_id } call add_goodbye { input: - hello_world = add_world.out, + input_file = add_world.out_world, + docker_image_id = docker_image_id + } + + call add_farewell { + input: + input_file = add_goodbye.out_goodbye, docker_image_id = docker_image_id } output { - File out = add_world.out + File out_world = add_world.out_world File out_goodbye = add_goodbye.out_goodbye + File out_farewell = add_farewell.out_farewell } } task add_world { input { - File hello + File input_file String docker_image_id } command <<< - cat ~{hello} > out.txt - echo world >> out.txt + cat ~{input_file} > out_world.txt + echo world >> out_world.txt >>> output { - File out = "out.txt" + File out_world = "out_world.txt" } runtime { @@ -58,12 +65,12 @@ task add_goodbye { input { - File hello_world + File input_file String docker_image_id } command <<< - cat ~{hello_world} > out_goodbye.txt + cat ~{input_file} > out_goodbye.txt echo goodbye >> out_goodbye.txt >>> @@ -75,6 +82,26 @@ docker: docker_image_id } } + +task add_farewell { + input { + File input_file + String docker_image_id + } + + command <<< + cat ~{input_file} > out_farewell.txt + echo farewell >> out_farewell.txt + >>> + + output { + File out_farewell = "out_farewell.txt" + } + + runtime { + docker: docker_image_id + } +} """ test_fail_wdl = """ @@ -161,7 +188,7 @@ test_stage_io_map = { "Two": { - "hello_world": "out", + "hello_world": "out_world", }, } @@ -301,11 +328,12 @@ def test_simple_sfn_wdl_workflow(self): output = json.loads(description["output"]) self.assertEqual(output["Result"], { - "swipe_test.out": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out.txt", + "swipe_test.out_world": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_world.txt", "swipe_test.out_goodbye": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_goodbye.txt", + "swipe_test.out_farewell": f"s3://{self.input_obj.bucket_name}/{output_prefix}/test-1/out_farewell.txt", }) - outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out.txt") + outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_world.txt") output_text = outputs_obj.get()["Body"].read().decode() self.assertEqual(output_text, "hello\nworld\n") @@ -384,19 +412,19 @@ def test_call_cache(self): self.sqs.receive_message( QueueUrl=self.state_change_queue_url, MaxNumberOfMessages=1 ) - outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out.txt") + outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_world.txt") output_text = outputs_obj.get()["Body"].read().decode() self.assertEqual(output_text, "hello\nworld\n") - self.test_bucket.Object(f"{output_prefix}/test-1/out.txt").put( + self.test_bucket.Object(f"{output_prefix}/test-1/out_goodbye.txt").put( Body="cache_break\n".encode() ) - self.test_bucket.Object(f"{output_prefix}/test-1/out_goodbye.txt").delete() + self.test_bucket.Object(f"{output_prefix}/test-1/out_farewell.txt").delete() # clear cache to simulate getting cut off the step before this one objects = self.s3_client.list_objects_v2( Bucket=self.test_bucket.name, - Prefix=f"{output_prefix}/test-1/cache/add_goodbye/", + Prefix=f"{output_prefix}/test-1/cache/add_farewell/", )["Contents"] self.test_bucket.Object(objects[0]["Key"]).delete() objects = self.s3_client.list_objects_v2( @@ -412,9 +440,9 @@ def test_call_cache(self): for v in outputs.values(): self.assert_(v.startswith("s3://"), f"{v} does not start with 's3://'") - outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_goodbye.txt") + outputs_obj = self.test_bucket.Object(f"{output_prefix}/test-1/out_farewell.txt") output_text = outputs_obj.get()["Body"].read().decode() - self.assertEqual(output_text, "cache_break\ngoodbye\n") + self.assertEqual(output_text, "cache_break\nfarewell\n") def test_zip_wdls(self): output_prefix = "zip-output" diff --git a/version b/version index 18fa8e74..23c38c24 100644 --- a/version +++ b/version @@ -1 +1 @@ -v1.3.0 +v1.3.1 \ No newline at end of file