You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
286 lines
9.1 KiB
286 lines
9.1 KiB
"""semantic_gravity_examples.py — Show concrete motion examples for SVD axes across windows.
|
|
|
|
For each axis and window, finds motions closest to the semantic gravity vector,
|
|
providing concrete examples of what the axis "means" in that period.
|
|
|
|
Usage:
|
|
uv run python scripts/semantic_gravity_examples.py --db data/motions.db --axis 1
|
|
uv run python scripts/semantic_gravity_examples.py --db data/motions.db --all
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import argparse
|
|
import json
|
|
import os
|
|
import sys
|
|
from typing import Dict, List, Tuple
|
|
|
|
import duckdb
|
|
import numpy as np
|
|
|
|
|
|
def _load_fused_embeddings_with_titles(
    con: duckdb.DuckDBPyConnection, window_id: str
) -> List[Tuple[int, np.ndarray, str]]:
    """Load fused embeddings joined with motion titles for one window.

    Stored vectors may arrive as JSON text, JSON-encoded bytes, a Python list,
    or any other iterable. They are converted to float arrays, with ``None``
    elements replaced by 0.0. Rows whose vector is NULL (or JSON ``null``) are
    skipped because they carry no embedding.

    Args:
        con: Open connection exposing ``fused_embeddings`` and ``motions``.
        window_id: Window identifier to filter on.

    Returns:
        ``(motion_id, vector, title)`` tuples. ``motion_id`` is cast to ``int``
        so it compares equal to the integer keys used for SVD scores, and a
        NULL title becomes ``""``.
    """
    rows = con.execute(
        """
        SELECT f.motion_id, f.vector, m.title
        FROM fused_embeddings f
        JOIN motions m ON f.motion_id = m.id
        WHERE f.window_id = ?
        """,
        [window_id],
    ).fetchall()

    result: List[Tuple[int, np.ndarray, str]] = []
    for motion_id, raw_vec, title in rows:
        if isinstance(raw_vec, (bytes, bytearray)):
            raw_vec = raw_vec.decode()
        vec = json.loads(raw_vec) if isinstance(raw_vec, str) else raw_vec
        if vec is None:
            # NULL column or JSON "null": nothing to embed (iterating would raise).
            continue
        result.append(
            (
                int(motion_id),
                np.array([float(v) if v is not None else 0.0 for v in vec]),
                title or "",
            )
        )
    return result
|
|
|
|
|
|
def _load_motion_scores(
    con: duckdb.DuckDBPyConnection, window_id: str
) -> Dict[int, np.ndarray]:
    """Load per-motion SVD score vectors for one window.

    Only rows with ``entity_type = 'motion'`` are read. Vectors may be JSON
    text, JSON-encoded bytes, a list, or another iterable. ``None`` elements
    become 0.0, and rows whose vector is NULL (or JSON ``null``) are skipped.

    Args:
        con: Open connection exposing ``svd_vectors``.
        window_id: Window identifier to filter on.

    Returns:
        ``{motion_id: score_array}`` with ``motion_id`` cast to ``int``. If an
        id appears more than once, the last row wins.
    """
    rows = con.execute(
        "SELECT entity_id, vector FROM svd_vectors WHERE window_id = ? AND entity_type = 'motion'",
        [window_id],
    ).fetchall()

    result: Dict[int, np.ndarray] = {}
    for entity_id, raw_vec in rows:
        if isinstance(raw_vec, (bytes, bytearray)):
            raw_vec = raw_vec.decode()
        vec = json.loads(raw_vec) if isinstance(raw_vec, str) else raw_vec
        if vec is None:
            # NULL column or JSON "null": no scores to record (iterating would raise).
            continue
        result[int(entity_id)] = np.array(
            [float(v) if v is not None else 0.0 for v in vec]
        )
    return result
|
|
|
|
|
|
def compute_semantic_gravity_examples(
    con: duckdb.DuckDBPyConnection,
    windows: List[str],
    axis: int,
    n_examples: int = 5,
    n_components: int = 10,
    min_motions: int = 10,
) -> Dict:
    """Find motions closest to the semantic gravity of an SVD axis, per window.

    For each window, the "semantic gravity" is the mean fused embedding of the
    window's motions, weighted by the absolute SVD score of each motion on
    ``axis``. Motions are ranked by cosine similarity to that vector, and the
    strongest loadings at each pole of the axis are collected.

    Args:
        con: Open connection with ``svd_vectors``, ``fused_embeddings`` and
            ``motions`` tables.
        windows: Window ids to analyse. Windows lacking data are skipped.
        axis: 1-based SVD axis number.
        n_examples: Maximum number of motions in each example list (per pole
            for the extremes). Values <= 0 yield empty lists.
        n_components: Currently unused; kept for backward compatibility.
        min_motions: Minimum number of motions having both SVD scores and an
            embedding for a window to be analysed.

    Returns:
        ``{window_id: {"gravity": ndarray, "top_similar": [...],
        "top_dissimilar": [...], "extreme": [...]}}``.

        - The similarity lists hold ``(cosine, motion_id, title)`` tuples.
          ``"top_similar"`` keeps positive cosines, most similar first;
          ``"top_dissimilar"`` keeps negative cosines, most negative first.
        - ``"extreme"`` holds ``(score, motion_id, title)`` tuples: the top
          positive loadings, then the top negative loadings, each ordered by
          decreasing magnitude.

    Raises:
        ValueError: If ``axis`` is less than 1.
    """
    if axis < 1:
        # The axis is 1-based. 0 or a negative value would silently wrap around
        # to the last SVD components through negative indexing.
        raise ValueError(f"axis must be >= 1, got {axis}")
    comp_idx = axis - 1
    # Clamp the count: a [-0:] slice would otherwise return every element.
    n = max(n_examples, 0)
    results: Dict = {}

    for w in windows:
        motion_scores = _load_motion_scores(con, w)
        embeddings_data = _load_fused_embeddings_with_titles(con, w)
        if not motion_scores or not embeddings_data:
            continue

        embeddings_by_id = {mid: (vec, title) for mid, vec, title in embeddings_data}

        # Motions that have both SVD scores and an embedding in this window.
        common = [m for m in motion_scores if m in embeddings_by_id]
        if len(common) < min_motions:
            continue

        # Only motions whose score vector reaches this axis can be weighted,
        # ranked or reported. Indexing the others would raise IndexError.
        axis_scores = {
            m: float(motion_scores[m][comp_idx])
            for m in common
            if comp_idx < len(motion_scores[m])
        }
        valid = list(axis_scores)
        if not valid:
            continue
        weights = np.abs(np.array([axis_scores[m] for m in valid]))
        if weights.sum() == 0:
            continue

        # Semantic gravity is the embedding mean weighted by |score| on this
        # axis. Vectors are truncated to the shortest embedding so they stack.
        dim = min(len(embeddings_by_id[m][0]) for m in valid)
        vectors = np.array([embeddings_by_id[m][0][:dim] for m in valid])
        gravity = np.average(vectors, axis=0, weights=weights)

        # Cosine similarity of each valid motion to the gravity vector.
        # Zero-norm vectors have no direction and are skipped.
        similarities = []
        norm_g = np.linalg.norm(gravity)
        if norm_g > 0:
            norms = np.linalg.norm(vectors, axis=1)
            for m_id, vec, norm_v in zip(valid, vectors, norms):
                if norm_v > 0:
                    sim = float(np.dot(gravity, vec) / (norm_g * norm_v))
                    similarities.append((sim, m_id, embeddings_by_id[m_id][1]))
        similarities.sort(reverse=True)
        top_positive = [s for s in similarities if s[0] > 0][:n]
        top_negative = [s for s in reversed(similarities) if s[0] < 0][:n]

        # Strongest loadings at each pole, so both sides of the axis are
        # represented (a single |score| ranking often covers only one pole).
        positive_pole = sorted(
            (m for m in valid if axis_scores[m] > 0),
            key=axis_scores.get,
            reverse=True,
        )[:n]
        negative_pole = sorted(
            (m for m in valid if axis_scores[m] < 0),
            key=axis_scores.get,
        )[:n]
        extreme_motions = [
            (axis_scores[m], m, embeddings_by_id[m][1])
            for m in positive_pole + negative_pole
        ]

        results[w] = {
            "gravity": gravity,
            "top_similar": top_positive,
            "top_dissimilar": top_negative,
            "extreme": extreme_motions,
        }

    return results
|
|
|
|
|
|
def _get_annual_windows(con: duckdb.DuckDBPyConnection) -> List[str]:
    """Return annual window ids that have both fused embeddings and motion SVD vectors.

    Windows whose id contains ``-Q`` (quarterly windows) are excluded. The
    result is sorted lexicographically by window id.

    The SVD-vector requirement is a semi-join (``EXISTS``). This avoids the
    per-window rows-times-rows blow-up that a plain JOIN followed by DISTINCT
    would produce.
    """
    rows = con.execute(
        """
        SELECT DISTINCT f.window_id
        FROM fused_embeddings f
        WHERE f.window_id NOT LIKE '%-Q%'
          AND EXISTS (
              SELECT 1
              FROM svd_vectors s
              WHERE s.window_id = f.window_id
                AND s.entity_type = 'motion'
          )
        ORDER BY f.window_id
        """
    ).fetchall()
    return [window_id for (window_id,) in rows]
|
|
|
|
|
|
def format_results(results: Dict, axis: int, *, max_title_len: int = 100) -> str:
    """Render semantic-gravity results for one axis as a markdown report.

    Args:
        results: Mapping ``window_id -> data``.
            - ``data["extreme"]`` holds ``(score, motion_id, title)`` tuples;
              positive and negative scores are listed under separate headings
              and zero scores are omitted.
            - ``data["top_similar"]`` holds ``(similarity, motion_id, title)``
              tuples.
            - Other keys (for example ``"gravity"``) are not read.
        axis: Axis number shown in the report title.
        max_title_len: Titles longer than this are cut and suffixed with ``...``.

    Returns:
        The markdown text, with windows in sorted order. Lines are joined with
        ``"\\n"``, and the text ends with a newline (the last line is blank).
    """

    def _clip(title: str) -> str:
        # Missing titles render as empty text instead of crashing on slicing.
        title = title or ""
        suffix = "..." if len(title) > max_title_len else ""
        return f"{title[:max_title_len]}{suffix}"

    lines = [
        f"# Semantic Gravity Examples for Axis {axis}",
        "",
        "Shows motions closest to semantic gravity (weighted mean embedding) for each window.",
        "This represents the 'typical' motion on this axis.",
        "",
        "---",
        "",
    ]

    for window in sorted(results):
        data = results[window]
        lines += [f"## {window}", ""]

        # Positive-pole extreme motions.
        lines.append("### Extreme Positive Motions (high positive loading)")
        lines += [
            f"- **[{score:+.3f}]** {_clip(title)}"
            for score, _m_id, title in data["extreme"]
            if score > 0
        ]
        lines.append("")

        # Negative-pole extreme motions.
        lines.append("### Extreme Negative Motions (high negative loading)")
        lines += [
            f"- **[{score:+.3f}]** {_clip(title)}"
            for score, _m_id, title in data["extreme"]
            if score < 0
        ]
        lines.append("")

        # Motions closest to the semantic gravity vector.
        lines.append("### Most Representative Motions (closest to semantic gravity)")
        lines += [
            f"- **[{sim:.3f}]** {_clip(title)}"
            for sim, _m_id, title in data["top_similar"]
        ]
        lines.append("")

    return "\n".join(lines)
|
|
|
|
|
|
def main(argv: List[str] | None = None) -> int:
    """Command-line entry point: report semantic gravity examples as markdown.

    Analyses ``--axis``, or every axis from 1 to ``--n-components`` when
    ``--all`` is given. Windows come from ``--windows``, or from all annual
    windows in the database when that option is omitted. The report goes to
    ``--output`` (written as UTF-8) or to stdout; progress messages go to
    stderr.

    Args:
        argv: Argument list to parse. ``None`` makes argparse use
            ``sys.argv[1:]``.

    Returns:
        0 on success. 1 if the database file is missing or fewer than two
        windows are available. Invalid argument values make argparse exit
        with status 2.
    """
    p = argparse.ArgumentParser(
        description="Find semantic gravity examples for SVD axes"
    )
    p.add_argument("--db", default="data/motions.db", help="Path to motions database")
    p.add_argument(
        "--axis",
        type=int,
        default=1,
        help="SVD axis to analyze (1-based; ignored with --all)",
    )
    p.add_argument(
        "--all",
        action="store_true",
        help="Analyze every axis from 1 to --n-components",
    )
    p.add_argument(
        "--n-components",
        type=int,
        default=10,
        help="Number of SVD axes analyzed by --all (default: 10)",
    )
    p.add_argument(
        "--windows", nargs="+", help="Specific windows (default: all annual windows)"
    )
    p.add_argument(
        "--n-examples",
        type=int,
        default=5,
        help="Number of example motions per category",
    )
    p.add_argument("--output", help="Output file (default: print to stdout)")

    args = p.parse_args(argv)

    # Axis numbers are 1-based; 0 or a negative value would index the score
    # vector from the end. Non-positive counts make the report meaningless.
    if args.axis < 1:
        p.error(f"--axis must be >= 1 (got {args.axis})")
    if args.n_components < 1:
        p.error(f"--n-components must be >= 1 (got {args.n_components})")
    if args.n_examples < 1:
        p.error(f"--n-examples must be >= 1 (got {args.n_examples})")

    if not os.path.exists(args.db):
        print(f"Error: Database not found: {args.db}", file=sys.stderr)
        return 1

    axes = list(range(1, args.n_components + 1)) if args.all else [args.axis]

    con = duckdb.connect(database=args.db, read_only=True)
    try:
        # Determine windows.
        if args.windows:
            windows = args.windows
        else:
            windows = _get_annual_windows(con)
            print(f"Found {len(windows)} annual windows: {windows}", file=sys.stderr)

        if len(windows) < 2:
            print("Need at least 2 windows for analysis", file=sys.stderr)
            return 1

        sections = []
        for axis in axes:
            print(
                f"Computing semantic gravity examples for Axis {axis}...",
                file=sys.stderr,
            )
            results = compute_semantic_gravity_examples(
                con, windows, axis, args.n_examples, n_components=args.n_components
            )
            if not results:
                print(
                    f"Warning: no window had enough data for Axis {axis}",
                    file=sys.stderr,
                )
            sections.append(format_results(results, axis))
    finally:
        con.close()

    # Each section already ends with a newline, so "\n" leaves one blank line
    # between axes. A single axis produces exactly one report.
    output = "\n".join(sections)

    if args.output:
        # Motion titles may contain non-ASCII text; do not rely on the
        # platform's default encoding.
        with open(args.output, "w", encoding="utf-8") as f:
            f.write(output)
        print(f"Results written to {args.output}", file=sys.stderr)
    else:
        print(output)

    return 0
|
|
|
|
|
|
if __name__ == "__main__":
    # Propagate main()'s integer status as the process exit code.
    exit_status = main()
    raise SystemExit(exit_status)
|
|
|