Skip to content

Commit

Permalink
Fix block scalar mangling bug lyz-code#231
Browse files Browse the repository at this point in the history
The regex based parsing for fixing comments was breaking block scalars.
By using the ruyaml round trip handler, instead the comment formatting
now can correctly identify block-scalars and avoid mangling them.
  • Loading branch information
wrouesnel committed Apr 6, 2023
1 parent d75a141 commit 2073d3c
Show file tree
Hide file tree
Showing 3 changed files with 160 additions and 16 deletions.
107 changes: 91 additions & 16 deletions src/yamlfix/adapters.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,19 @@
"""Define adapter / helper classes to hide unrelated functionality in."""

import io
import logging
import re
from functools import partial
from io import StringIO
from typing import Any, Callable, List, Match, Optional, Tuple

import ruyaml
from ruyaml.main import YAML
from ruyaml.nodes import MappingNode, Node, ScalarNode, SequenceNode
from ruyaml.representer import RoundTripRepresenter
from ruyaml.tokens import CommentToken

from yamlfix.model import YamlfixConfig, YamlNodeStyle
from yamlfix.util import walk_object

log = logging.getLogger(__name__)

Expand Down Expand Up @@ -590,24 +592,97 @@ def _restore_truthy_strings(source_code: str) -> str:
def _fix_comments(self, source_code: str) -> str:
log.debug("Fixing comments...")
config = self.config
comment_start = " " * config.comments_min_spaces_from_content + "#"

fixed_source_lines = []
# We need the source lines for the comment fixers to analyze whitespace easily
source_lines = source_code.splitlines()

for line in source_code.splitlines():
# Comment at the start of the line
if config.comments_require_starting_space and re.search(r"(^|\s)#\w", line):
line = line.replace("#", "# ")
# Comment in the middle of the line, but it's not part of a string
if (
config.comments_min_spaces_from_content > 1
and " #" in line
and line[-1] not in ["'", '"']
):
line = re.sub(r"(.+\S)(\s+?)#", rf"\1{comment_start}", line)
fixed_source_lines.append(line)
# We need to parse the file with ruamel.yaml's roundtrip dumper to keep a track of whether any
# individual line is part of an appropriate block scalar.
yaml = ruyaml.YAML(typ="rt")
yaml_documents = list(yaml.load_all(source_code))

return "\n".join(fixed_source_lines)
def _comment_token_cb(o: Any, key: Optional[Any] = None):
if not isinstance(o, CommentToken):
return
if o.value is None:
return
comment_lines = o.value.split("\n")
fixed_comment_lines = []
for line in comment_lines:
if config.comments_require_starting_space and re.search(
r"(^|\s)#\w", line
):
line = line.replace("#", "# ")
fixed_comment_lines.append(line)

# Update the comment with the fixed lines
o.value = "\n".join(fixed_comment_lines)

if config.comments_min_spaces_from_content > 1:
# It's hard to reconstruct exactly where the content is, but since we have the line numbers
# what we do is lookup the literal source line here and check where the whitespace is compared
# to where we know the comment starts.
source_line = source_lines[o.start_mark.line]
content_part = source_line[0 : o.start_mark.column]
# Find the non-whitespace position in the content part
m = re.match(r"^.*\S$", content_part)
if (
m is not None
): # If no match then nothing to do - no content to be away from
content_start, content_end = m.span()
# If there's less than min-spaces from content, we're going to add some.
if (
o.start_mark.column - content_end
< config.comments_min_spaces_from_content
):
# Handled
o.start_mark.column = (
content_end + config.comments_min_spaces_from_content
)

def _comment_fixer(o: Any, key: Optional[Any] = None):
"""
This function is the callback for walk_object
walk_object calls it for every object it finds, and then will walk the mapping/sequence subvalues and
call this function on those too. This gives us direct access to all round tripped comments.
"""
if not hasattr(o, "ca"):
# Scalar or other object with no comment parameter.
return
# Find all comment tokens and fix them
walk_object(o.ca.comment, _comment_token_cb)
walk_object(o.ca.end, _comment_token_cb)
walk_object(o.ca.items, _comment_token_cb)
walk_object(o.ca.pre, _comment_token_cb)

# Walk the object and invoke the comment fixer
walk_object(yaml_documents, _comment_fixer)

# Dump out the YAML documents
stream = io.StringIO()
yaml.dump_all(yaml_documents, stream=stream)

# Scan the source lines for a leading "---" separator. If it's found, add it.
found_leading_document_separator = False
for line in source_lines:
stripped_line = line.strip()
if stripped_line.startswith("#"):
continue
if line.rstrip() == "---":
found_leading_document_separator = True
break
if stripped_line != "":
# Found non-comment content.
break

fixed_source_code = ""
if found_leading_document_separator:
fixed_source_code += "---\n"
fixed_source_code += stream.getvalue()

# Re-invoke the ruamel.yaml fixer to get the corrected re-encoding.
return self._ruamel_yaml_fixer(fixed_source_code)

def _fix_whitelines(self, source_code: str) -> str:
"""Fixes number of consecutive whitelines.
Expand Down
24 changes: 24 additions & 0 deletions src/yamlfix/util.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
from typing import Any, Callable, Iterable, Mapping, Optional

from typing_extensions import Protocol


class ObjectCallback(Protocol):
def __call__(self, value: Any, key: Optional[Any] = None) -> None:
...


def walk_object(o: Any, fn: ObjectCallback):
"""Walk a YAML/JSON-like object and call a function on all values"""

# Call the callback and whatever we received.
fn(o)

if isinstance(o, Mapping):
# Map type
for key, value in o.items():
walk_object(value, fn)
elif isinstance(o, Iterable) and not isinstance(o, (bytes, str)):
# List type
for idx, value in enumerate(o):
walk_object(value, fn)
45 changes: 45 additions & 0 deletions tests/unit/test_adapter_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -830,3 +830,48 @@ def test_section_whitelines_begin_no_explicit_start(self) -> None:
result = fix_code(source, config)

assert result == fixed_source

def test_block_scalar_whitespace_is_preserved(self) -> None:
source = dedent(
"""\
---
addn_doc_key: |-
#######################################
# This would also be broken #
#######################################
---
#Comment above the key
key: |-
###########################################
# Value with lots of whitespace #
# Some More Whitespace #
###########################################
#Comment below
#Comment with some whitespace below
"""
)

fixed_source = dedent(
"""\
---
addn_doc_key: |-
#######################################
# This would also be broken #
#######################################
---
# Comment above the key
key: |-
###########################################
# Value with lots of whitespace #
# Some More Whitespace #
###########################################
# Comment below
# Comment with some whitespace below
"""
)

config = YamlfixConfig()
result = fix_code(source, config)
assert result == fixed_source

0 comments on commit 2073d3c

Please sign in to comment.