#!/usr/bin/env python3

import argparse
import difflib
import os
import re
import shutil
import subprocess
import sys
import tempfile


# Version banner printed by the -v/--version option.
VERSION = "clang-format-radare2 1.0"
# Directory that contains this script.
SCRIPT_DIR = os.path.dirname(os.path.abspath(__file__))
# Parent directory of SCRIPT_DIR; relative whitelist entries are resolved against it.
PROJECT_ROOT = os.path.abspath(os.path.join(SCRIPT_DIR, os.pardir))
# Whitelist file (next to this script) listing the files handled in --auto mode.
AUTO_WHITELIST_PATH = os.path.join(SCRIPT_DIR, "auto-format-files.txt")

# clang-format style definition. main() writes it to a temporary file and
# format_file() hands that file to clang-format through -style=file:<path>;
# --print-config dumps it verbatim.
CLANG_FORMAT_CONFIG = """BasedOnStyle: LLVM
Language: Cpp
PointerAlignment: Right
AlwaysBreakAfterDefinitionReturnType: None
BinPackParameters: false
BinPackArguments: false
MaxEmptyLinesToKeep: 1
SpaceInEmptyParentheses: false
SpacesInContainerLiterals: true
SpaceBeforeParens: Custom
SpaceBeforeParensOptions:
  AfterIfMacros: true
  AfterFunctionDefinitionName: false
  AfterFunctionDeclarationName: false
  AfterForeachMacros: true
  AfterControlStatements: true
  BeforeNonEmptyParentheses: false
SpacesInParentheses: false
InsertBraces: true
ContinuationIndentWidth: 8
IndentCaseLabels: false
IndentFunctionDeclarationAfterType: false
IndentWidth: 8
UseTab: ForContinuationAndIndentation
ColumnLimit: 0
BreakBeforeBraces: Attach
BreakBeforeTernaryOperators: true
SpaceBeforeSquareBrackets: false
SpaceInEmptyBlock: false
AllowShortIfStatementsOnASingleLine: false
AllowShortCaseLabelsOnASingleLine: true
AllowShortFunctionsOnASingleLine: Inline
AllowShortLoopsOnASingleLine: false
AlignAfterOpenBracket: DontAlign
AlignEscapedNewlines: DontAlign
AlignConsecutiveMacros: true
AlignTrailingComments: false
AlignOperands: false
Cpp11BracedListStyle: false
ForEachMacros: ['r_list_foreach', 'ls_foreach', 'fcn_tree_foreach_intersect', 'r_skiplist_foreach', 'graph_foreach_anode', 'r_list_foreach_safe', 'R_VEC_FOREACH', 'R_VEC_FOREACH_I', 'R_VEC_FOREACH_PREV', 'r_rbtree_foreach', 'r_interval_tree_foreach']
SortIncludes: false
"""


def parse_args():
	"""Build the command-line parser and parse the process arguments.

	Returns:
		The argparse namespace with the attributes ``in_place``,
		``no_update``, ``version``, ``auto``, ``files``, ``clang_format``
		and ``print_config``.
	"""
	# (flags, keyword arguments) pairs, registered in this order so the
	# generated --help output keeps its layout.
	option_table = (
		(
			("-i", "--in-place"),
			{
				"dest": "in_place",
				"action": "store_true",
				"help": "Do nothing, because that's the default behaviour, for muscle memory compatibility reasons with clang-format",
			},
		),
		(
			("-n", "--no-update"),
			{
				"dest": "no_update",
				"action": "store_true",
				"help": "Do not modify files; report differences and exit with error if formatting is needed.",
			},
		),
		(
			("-v", "--version"),
			{
				"action": "store_true",
				"help": "Show the clang-format-radare2 version and exit.",
			},
		),
		(
			("-a", "--auto"),
			{
				"action": "store_true",
				"help": "Auto mode: ignore provided paths and use the indentation whitelist.",
			},
		),
		(
			("files",),
			{
				"nargs": "*",
				"help": "Files or directories to format in place.",
			},
		),
		(
			("--clang-format",),
			{
				"dest": "clang_format",
				# The CLANG_FORMAT environment variable overrides the default binary name.
				"default": os.environ.get("CLANG_FORMAT", "clang-format"),
				"help": "clang-format executable to use (default: env CLANG_FORMAT or clang-format).",
			},
		),
		(
			("--print-config",),
			{
				"action": "store_true",
				"help": "Print the embedded .clang-format configuration to stdout and exit.",
			},
		),
	)
	parser = argparse.ArgumentParser(
		description="Format files using clang-format followed by radare2 indentation rules."
	)
	for flags, settings in option_table:
		parser.add_argument(*flags, **settings)
	return parser.parse_args()


def is_source_file(path):
	"""Return True when *path* has a ``.c`` or ``.h`` extension (case-insensitive)."""
	_, suffix = os.path.splitext(path)
	return suffix.lower() in (".c", ".h")


def expand_targets(paths):
	"""Turn a list of files and directories into a flat list of files.

	Directories are walked recursively and contribute every entry accepted
	by is_source_file(); anything else is passed through untouched. Empty
	entries are skipped and each resulting path string is returned once, in
	first-seen order.
	"""
	collected = []
	visited = set()

	def remember(candidate):
		# Record a path the first time it shows up.
		if candidate not in visited:
			visited.add(candidate)
			collected.append(candidate)

	for target in paths:
		if not target or target in visited:
			continue
		if not os.path.isdir(target):
			remember(target)
			continue
		for root, _dirs, names in os.walk(target):
			for name in names:
				if is_source_file(name):
					remember(os.path.join(root, name))
	return collected


def load_auto_files():
	"""Read the auto-mode whitelist and return the existing files it names.

	Blank lines and lines starting with ``#`` are ignored. Relative entries
	are resolved against PROJECT_ROOT. Entries that are not regular files
	and duplicates are dropped. An unreadable whitelist yields an empty list.
	"""
	try:
		with open(AUTO_WHITELIST_PATH, "r", encoding="utf-8") as handle:
			entries = [raw.strip() for raw in handle]
	except OSError:
		return []
	selected = []
	known = set()
	for entry in entries:
		if entry == "" or entry.startswith("#"):
			continue
		candidate = entry
		if not os.path.isabs(entry):
			candidate = os.path.join(PROJECT_ROOT, entry)
		if candidate not in known and os.path.isfile(candidate):
			known.add(candidate)
			selected.append(candidate)
	return selected


def show_diff(path, original, formatted):
	"""Write a unified diff between *original* and *formatted* to stdout.

	The diff headers use *path* relative to the current directory when that
	can be computed, otherwise *path* as given.
	"""
	try:
		label = os.path.relpath(path)
	except ValueError:
		# relpath cannot relate the two locations; fall back to the raw path.
		label = path
	hunks = difflib.unified_diff(
		original.splitlines(keepends=True),
		formatted.splitlines(keepends=True),
		fromfile=label,
		tofile=f"{label} (formatted)",
	)
	sys.stdout.writelines(hunks)


def convert_leading_spaces(line, tab_width=8):
	"""Rewrite the leading whitespace of *line* using tabs only.

	The visual width of the leading run of spaces and tabs is measured and
	replaced by ``width // tab_width`` tab characters; a remainder smaller
	than one tab stop is dropped.

	Args:
		line: The line to convert (may include its line terminator).
		tab_width: Number of columns per tab stop.

	Returns:
		The line with tab-only indentation, or *line* unchanged when it has
		no leading spaces or tabs.
	"""
	rest = line.lstrip(" \t")
	indent = line[:len(line) - len(rest)]
	if not indent:
		return line
	column = 0
	for ch in indent:
		if ch == "\t":
			# A tab jumps to the next tab stop; it is not a fixed-width
			# character, so spaces in front of it must not be counted twice.
			column += tab_width - column % tab_width
		else:
			column += 1
	return "\t" * (column // tab_width) + rest


def fix_paren_spacing(line):
	"""Normalize the spacing in front of every ``(`` outside quoted literals.

	Spaces directly before the parenthesis are always removed. A single
	space is then inserted, except when:
	- the line starts with an ASCII letter or digit in column 0,
	- the line starts with ``#``,
	- the parenthesis is the first thing in the line, or
	- the preceding character is ``(``, a tab, a newline, ``*``, ``_``,
	  ``&`` or ``[``.
	"""
	never_space = bool(re.match(r"[A-Za-z0-9]", line)) or line.startswith("#")
	no_space_after = ("", "(", "\t", "\n", "*", "_", "&", "[")
	out = []
	quote = None  # quote character of the literal being scanned, if any
	escaped = False
	for ch in line:
		if escaped:
			# Character following a backslash inside a literal: copy verbatim.
			out.append(ch)
			escaped = False
			continue
		if quote is not None:
			if ch == "\\":
				escaped = True
			elif ch == quote:
				quote = None
			out.append(ch)
			continue
		if ch == "'" or ch == '"':
			quote = ch
			out.append(ch)
			continue
		if ch == "(":
			while out and out[-1] == " ":
				out.pop()
			before = out[-1] if out else ""
			if not never_space and before not in no_space_after:
				out.append(" ")
		out.append(ch)
	return "".join(out)


def split_line_ending(line):
	"""Split *line* into ``(content, terminator)``.

	The terminator is ``"\\r\\n"``, ``"\\n"`` or ``""`` when the line does
	not end with a newline.
	"""
	# "\r\n" is tested first so that it is not mistaken for a bare "\n".
	for terminator in ("\r\n", "\n"):
		if line.endswith(terminator):
			return line[:-len(terminator)], terminator
	return line, ""


# "case ...:" or "default:" followed by an opening brace on the same line.
# Group 1 is the label (with its indentation), group 2 an optional trailing
# "//" comment.
CASE_OPEN_RE = re.compile(r"^(\s*(?:case\b.*:|default:))\s*\{(\s*//.*)?$")
# "}" followed by "break;" on the same line. Group 1 is the indentation,
# group 2 whatever follows "break;".
CASE_CLOSE_BREAK_RE = re.compile(r"^(\s*)\}\s*break;(.*)$")
# Any line whose first non-blank character is "}". Group 1 is the
# indentation, group 2 the rest of the line.
CASE_CLOSE_ONLY_RE = re.compile(r"^(\s*)\}(.*)$")
# A line holding only "identifier:" plus an optional "//" comment.
# Group 1 is the indentation, group 2 the identifier, group 3 the trailing
# whitespace/comment.
LABEL_RE = re.compile(r"^(\s*)([A-Za-z_][A-Za-z0-9_]*):(\s*(?://.*)?)?$")


def fix_case_blocks(lines):
	"""Re-layout braced ``case``/``default`` blocks.

	For a label that opens a block on the same line (``case X: {``) the brace
	is moved to its own line, one tab deeper than the label, and the lines
	that follow are pushed one extra tab in until the block is closed.
	``} break;`` is split into two lines.

	Args:
		lines: List of lines, each optionally ending in a line terminator.

	Returns:
		A new list of lines. Rewritten lines that had no terminator get "\\n".
	"""
	fixed = []
	# One entry per open case block: (label indent, brace indent, body indent).
	case_indent_stack = []
	for line in lines:
		content, ending = split_line_ending(line)
		line_sep = ending or "\n"
		match_open = CASE_OPEN_RE.match(content)
		if match_open:
			# "case X: { // note" -> "case X: // note" + a line with the brace.
			case_line = match_open.group(1).rstrip()
			comment = match_open.group(2)
			if comment:
				case_line = f"{case_line} {comment.strip()}"
			indent_match = re.match(r"^(\s*)", content)
			indent = indent_match.group(1) if indent_match else ""
			fixed.append(case_line + line_sep)
			fixed.append(indent + "\t{" + line_sep)
			case_indent_stack.append((indent, indent + "\t", indent + "\t\t"))
			continue
		match_close = CASE_CLOSE_BREAK_RE.match(content)
		if match_close:
			# "} break;" closes the innermost open case block, if any. The line
			# is split even when no block is being tracked; in that case its
			# own indentation is reused.
			indent = match_close.group(1)
			comment = match_close.group(2).strip()
			if case_indent_stack:
				_, inner_indent, body_indent = case_indent_stack.pop()
			else:
				inner_indent = indent
			fixed.append(inner_indent + "}" + line_sep)
			if comment:
				fixed.append(inner_indent + "break; " + comment + line_sep)
			else:
				fixed.append(inner_indent + "break;" + line_sep)
			continue
		match_close_only = CASE_CLOSE_ONLY_RE.match(content)
		if match_close_only and case_indent_stack:
			indent = match_close_only.group(1)
			case_indent, inner_indent, body_indent = case_indent_stack[-1]
			# Only a "}" sitting at the label's own indentation closes the case
			# block; deeper braces belong to statements inside the body.
			if indent == case_indent:
				case_indent_stack.pop()
				remainder = match_close_only.group(2).strip()
				if remainder:
					fixed.append(inner_indent + "}" + " " + remainder + line_sep)
				else:
					fixed.append(inner_indent + "}" + line_sep)
				continue
		if case_indent_stack:
			_, inner_indent, body_indent = case_indent_stack[-1]
			# Body lines are shifted one tab deeper. Empty lines, lines starting
			# with "{", "?" or ":" and lines containing "||" or "&&" are left
			# as they are.
			if content and not content.lstrip().startswith("{") and not ("||" in content or "&&" in content) and not content.lstrip().startswith("?") and not content.lstrip().startswith(":"):
				if content.startswith(inner_indent):
					adjusted = body_indent + content[len(inner_indent):]
					fixed.append(adjusted + line_sep)
					continue
		fixed.append(line)
	return fixed


def fix_labels(lines):
	"""Move goto-style labels to column 0.

	A line made of just ``identifier:`` (optionally followed by a ``//``
	comment) loses its indentation. ``case`` and ``default`` are switch
	labels and are left alone, as is every other line. A rewritten label
	that had no line terminator gets "\\n".
	"""
	output = []
	for raw in lines:
		body, terminator = split_line_ending(raw)
		found = LABEL_RE.match(body)
		if found is None or found.group(2) in ("case", "default"):
			output.append(raw)
			continue
		trailing = found.group(3) or ""
		output.append(found.group(2) + ":" + trailing + (terminator or "\n"))
	return output


def fix_ternary_spacing(line):
	"""Format ternary operators as ``cond? a: b`` outside quoted literals.

	Spaces before ``?`` are removed and the spaces/tabs after it collapse to
	one space. No space is added at the end of the line or before ``)``,
	``,``, ``;`` or ``:``. A ``:`` gets the same treatment, but only while a
	preceding ``?`` on the line is still unmatched and without the ``:``
	exception, so other colons are left alone.
	"""
	out = []
	quote = ""  # quote character of the literal being scanned, "" outside
	pending = 0  # number of '?' still waiting for their ':'
	pos = 0
	end = len(line)
	while pos < end:
		ch = line[pos]
		pos += 1
		if quote:
			out.append(ch)
			if ch == "\\":
				# Copy the escaped character without interpreting it.
				if pos < end:
					out.append(line[pos])
					pos += 1
			elif ch == quote:
				quote = ""
			continue
		if ch == "'" or ch == '"':
			quote = ch
			out.append(ch)
			continue
		is_question = ch == "?"
		if is_question or (ch == ":" and pending > 0):
			pending += 1 if is_question else -1
			while out and out[-1] == " ":
				out.pop()
			out.append(ch)
			while pos < end and line[pos] in " \t":
				pos += 1
			stops = ("\n", ")", ",", ";", ":") if is_question else ("\n", ")", ",", ";")
			if pos < end and line[pos] not in stops:
				out.append(" ")
			continue
		out.append(ch)
	return "".join(out)


def fix_hash_spacing(line):
	"""Surround every ``##`` outside quoted literals with single spaces.

	Spaces before the operator are replaced by one space and spaces after it
	by one space.
	"""
	pieces = []
	quote = ""  # quote character of the literal being scanned, "" outside
	idx = 0
	total = len(line)
	while idx < total:
		ch = line[idx]
		if quote:
			pieces.append(ch)
			idx += 1
			if ch == "\\":
				# Copy the escaped character without interpreting it.
				if idx < total:
					pieces.append(line[idx])
					idx += 1
			elif ch == quote:
				quote = ""
			continue
		if ch == "'" or ch == '"':
			quote = ch
			pieces.append(ch)
			idx += 1
			continue
		if line.startswith("##", idx):
			while pieces and pieces[-1] == ' ':
				pieces.pop()
			# Kept as separate items so a directly following "##" can strip
			# the trailing space again.
			pieces.extend((' ', '##', ' '))
			idx += 2
			while idx < total and line[idx] == ' ':
				idx += 1
			continue
		pieces.append(ch)
		idx += 1
	return "".join(pieces)


def fix_logical_spacing(line):
	"""Make sure ``&&`` and ``||`` outside quoted literals are followed by whitespace.

	A single space is inserted after the operator when the next character is
	not a space, tab or newline; nothing is added at the very end of the
	string.
	"""
	chunks = []
	quote = ""  # quote character of the literal being scanned, "" outside
	pos = 0
	size = len(line)
	while pos < size:
		ch = line[pos]
		if quote:
			chunks.append(ch)
			pos += 1
			if ch == "\\":
				# Copy the escaped character without interpreting it.
				if pos < size:
					chunks.append(line[pos])
					pos += 1
			elif ch == quote:
				quote = ""
			continue
		if ch == "'" or ch == '"':
			quote = ch
			chunks.append(ch)
			pos += 1
			continue
		pair = line[pos:pos + 2]
		if pair == "&&" or pair == "||":
			chunks.append(pair)
			pos += 2
			if pos < size and line[pos] not in (' ', '\t', '\n'):
				chunks.append(" ")
			continue
		chunks.append(ch)
		pos += 1
	return "".join(chunks)


def format_macro_body(body):
	"""Run the per-line spacing fixers over the replacement text of a macro.

	An empty (falsy) body is returned unchanged.
	"""
	if not body:
		return body
	# Order matters: each fixer works on the output of the previous one.
	for fixer in (fix_paren_spacing, fix_ternary_spacing, fix_hash_spacing, fix_logical_spacing):
		body = fixer(body)
	return body


def normalize_define_line(line):
	"""Normalize the spacing of a ``#define`` directive.

	*line* must start with ``#define``. The result is
	``#define NAME[(params)] body`` with single spaces between the parts and
	the body run through format_macro_body(). The line is returned untouched
	when no macro name can be found or the parameter list is not closed.
	"""
	remainder = line[len("#define"):].lstrip()
	if not remainder:
		return "#define"
	# Length of the macro name: leading run of alphanumerics and underscores.
	name_end = 0
	for ch in remainder:
		if ch != '_' and not ch.isalnum():
			break
		name_end += 1
	if name_end == 0:
		return line
	head = "#define " + remainder[:name_end]
	if remainder.startswith('(', name_end):
		# Function-like macro: the parameter list runs up to the parenthesis
		# that balances the one glued to the name.
		nesting = 0
		close_at = -1
		for pos in range(name_end, len(remainder)):
			ch = remainder[pos]
			if ch == '(':
				nesting += 1
			elif ch == ')':
				nesting -= 1
				if nesting == 0:
					close_at = pos
					break
		if close_at < 0:
			return line
		head += remainder[name_end:close_at + 1]
		tail = remainder[close_at + 1:]
	else:
		tail = remainder[name_end:]
	macro_body = tail.lstrip()
	if not macro_body:
		return head
	return head + " " + format_macro_body(macro_body)


def fix_preprocessor_line(line):
	"""Normalize a preprocessor directive line.

	Lines whose first non-blank character is not ``#`` are returned
	unchanged. Otherwise the indentation is dropped, tabs become spaces,
	``#define`` lines go through normalize_define_line() and ``##`` spacing
	is fixed. The original line terminator is preserved.
	"""
	text, terminator = split_line_ending(line)
	directive = text.lstrip()
	if not directive.startswith('#'):
		return line
	directive = directive.replace('\t', ' ')
	if directive.startswith('#define'):
		directive = normalize_define_line(directive)
	# A single pass over the whole directive is enough: running
	# fix_hash_spacing() on its own output changes nothing.
	return fix_hash_spacing(directive) + terminator


def fix_multiline_comments(lines):
	"""Align the leading ``*`` of block-comment continuation lines.

	Inside a ``/* ... */`` comment that spans several lines, every line
	whose first non-blank character is ``*`` gets a space right before that
	``*`` unless one is already there.

	Only lines that begin inside a comment are touched, so a code line such
	as ``*p = 1; /* note */`` keeps its indentation. Comment delimiters are
	found by plain text search; string literals and ``//`` comments are not
	taken into account.

	Args:
		lines: List of lines, each optionally ending in a line terminator.

	Returns:
		A new list with the adjusted lines.
	"""
	result = []
	in_comment = False
	for line in lines:
		if in_comment:
			match = re.match(r'^(\s*)\*', line)
			if match:
				indent = match.group(1)
				if not indent.endswith(' '):
					line = indent + ' ' + line[len(indent):]
		# Walk the delimiters left to right so the state after the line is
		# right even with several comments on it ("/* a */ /* b") or with
		# overlapping characters ("/*/" opens a comment, it does not close it).
		pos = 0
		while True:
			marker = '*/' if in_comment else '/*'
			found = line.find(marker, pos)
			if found == -1:
				break
			in_comment = not in_comment
			pos = found + 2
		result.append(line)
	return result


def fix_ternary_lines(lines):
	"""Rearrange a ternary expression that is split over two lines.

	A pair of lines is rewritten when the first line contains a ``?`` with a
	``:`` somewhere after the last ``?``, and the following line starts
	(after indentation) with ``:``. The text from the last ``?`` of the
	first line is swapped with the content of the second line; all other
	lines are copied unchanged.

	NOTE(review): the comment that used to sit on the condition said the
	``?`` must NOT be followed by ``:`` on the same line, which is the
	opposite of what the regular expression tests - confirm the intent.
	"""
	fixed = []
	i = 0
	while i < len(lines):
		line = lines[i]
		content, ending = split_line_ending(line)
		# The regex matches when some '?' has no ':' between it and the end of
		# the line; "not" therefore selects lines where the last '?' is
		# followed by a ':'.
		if '?' in content and not re.search(r'\?\s*[^:]*$', content):
			# Only act when a following line exists and starts with ':'.
			if i + 1 < len(lines):
				next_line = lines[i + 1]
				next_content, next_ending = split_line_ending(next_line)
				if next_content.strip().startswith(':'):
					# Split the first line at its last '?'.
					last_q = content.rfind('?')
					if last_q != -1:
						before_q = content[:last_q]
						after_q = content[last_q:]
						# after_q is "? ..." from the first line;
						# next_content is "<indent>: ..." from the second.
						# The second line's indentation is reused for the moved part.
						indent_match = re.match(r'^(\s*)', next_content)
						indent = indent_match.group(1) if indent_match else ''
						# First output line: text before the '?' followed by the ': ...' part.
						new_line1 = before_q + next_content.strip() + ending
						# Second output line: the '? ...' part at the second line's indentation.
						new_line2 = indent + after_q + next_ending
						fixed.append(new_line1)
						fixed.append(new_line2)
						# Both input lines were consumed.
						i += 2
						continue
		fixed.append(line)
		i += 1
	return fixed


def apply_indent_rules(text):
	"""Apply the radare2 post-processing rules to clang-format output.

	Preprocessor lines are handled by fix_preprocessor_line(); every other
	line goes through the per-line fixers. The multi-line passes then run
	over the whole list and the result is joined back into one string.
	"""
	per_line_fixers = (
		convert_leading_spaces,
		fix_paren_spacing,
		fix_ternary_spacing,
		fix_hash_spacing,
		fix_logical_spacing,
	)
	whole_file_passes = (
		fix_case_blocks,
		fix_labels,
		fix_ternary_lines,
		fix_multiline_comments,
	)
	processed = []
	for raw in text.splitlines(True):
		if raw.lstrip().startswith('#'):
			processed.append(fix_preprocessor_line(raw))
			continue
		for fixer in per_line_fixers:
			raw = fixer(raw)
		processed.append(raw)
	for run_pass in whole_file_passes:
		processed = run_pass(processed)
	return "".join(processed)


def format_file(path, clang_format, style_file, check_only=False):
	"""Format one file with clang-format plus the radare2 indentation rules.

	Args:
		path: File to format.
		clang_format: clang-format executable to run.
		style_file: Path of the style file passed as ``-style=file:<path>``.
		check_only: When true the file is not modified; a unified diff is
			printed instead.

	Returns:
		True when the formatted text differs from the file content, False
		when the file is already formatted.

	Raises:
		FileNotFoundError: *path* is not a regular file.
		RuntimeError: clang-format exited with an error.
		OSError: The file could not be read or replaced.
	"""
	if not os.path.isfile(path):
		raise FileNotFoundError(f"{path}: no such file")
	with open(path, "r", encoding="utf-8") as current:
		original = current.read()
	try:
		result = subprocess.run(
			[clang_format, "-style=file:" + style_file, path],
			check=True,
			stdout=subprocess.PIPE,
			stderr=subprocess.PIPE,
			text=True,
		)
	except subprocess.CalledProcessError as exc:
		raise RuntimeError(
			f"clang-format failed for {path}: {exc.stderr.strip() or exc}"
		) from exc
	indented = apply_indent_rules(result.stdout)
	if original == indented:
		return False
	if check_only:
		show_diff(path, original, indented)
		return True
	# Write through symbolic links instead of replacing the link itself.
	target = os.path.realpath(path)
	# The temporary file must live next to the target: os.replace() cannot
	# move a file across filesystems, and the system temp directory is often
	# on a different one.
	fd, temp_name = tempfile.mkstemp(
		dir=os.path.dirname(target),
		prefix=".clang-format-radare2-",
		suffix=".tmp",
	)
	replaced = False
	try:
		with os.fdopen(fd, "w", encoding="utf-8") as tmp:
			tmp.write(indented)
		# mkstemp() creates the file with mode 0600; keep the permission
		# bits of the file being replaced.
		shutil.copymode(target, temp_name)
		os.replace(temp_name, target)
		replaced = True
	finally:
		if not replaced:
			# Do not leave the temporary file behind when anything failed.
			try:
				os.remove(temp_name)
			except OSError:
				pass
	return True


def main():
	"""Command-line entry point.

	Handles --print-config and --version, resolves the list of files (from
	the command line or, in --auto mode, from the whitelist) and formats
	each of them.

	Returns:
		0 on success; 1 when there was nothing to do, clang-format is
		missing, a file failed to format, or --no-update found a file that
		needs formatting.
	"""
	args = parse_args()
	if args.print_config:
		print(CLANG_FORMAT_CONFIG)
		return 0
	if args.version:
		print(VERSION)
		return 0
	paths = args.files
	if args.auto:
		paths = load_auto_files()
		if not paths:
			print(
				"clang-format-radare2: auto mode whitelist is empty or missing",
				file=sys.stderr,
			)
			return 1
	if not paths:
		print("clang-format-radare2: no input files", file=sys.stderr)
		return 1
	clang_format = args.clang_format
	if not shutil.which(clang_format):
		print(
			f"clang-format-radare2: cannot find clang-format ({clang_format})",
			file=sys.stderr,
		)
		return 1
	files = expand_targets(paths)
	if not files:
		print("clang-format-radare2: no input files", file=sys.stderr)
		return 1
	# A uniquely named style file: a fixed name in the shared temp directory
	# lets concurrent runs overwrite or delete each other's file and can be
	# pre-created by another user.
	fd, style_file = tempfile.mkstemp(prefix="clang-format-radare2-", suffix=".tmp")
	exit_code = 0
	try:
		with os.fdopen(fd, "w", encoding="utf-8") as handle:
			handle.write(CLANG_FORMAT_CONFIG)
		for path in files:
			try:
				changed = format_file(path, clang_format, style_file, args.no_update)
			except (FileNotFoundError, RuntimeError, OSError, UnicodeError) as err:
				# UnicodeError: a file that is not valid UTF-8 is reported
				# like any other per-file failure instead of aborting the run.
				print(f"clang-format-radare2: {err}", file=sys.stderr)
				exit_code = 1
				continue
			if args.no_update and changed:
				exit_code = 1
	finally:
		try:
			os.remove(style_file)
		except OSError:
			pass
	return exit_code


# Run main() only when executed as a script and use its return value as the
# process exit status.
if __name__ == "__main__":
	sys.exit(main())
