import json
from typing import Optional

# --- Input / output files (relative paths resolve against the working directory) ---
pnp_calib_file = "pnp_calibrations_filtered.json"             # calibration key -> per-camera extrinsics
cam_serials_file = "droid/camera_serials.json"                 # episode id -> camera serial fields
full_episode_map_file = "droid/full_episode_id_to_path.json"   # episode id -> {"r2d2": path, "gresearch": path}
output_file = "pnp_calibrations_converted.json"                # written by this script

# ---------- Helpers ----------

def extract_rel_path(full_path: str) -> str:
    """Extract ``LAB/STATUS/DATE/TIMESTAMP`` from an RLDS calibration key.

    The four ``/``-separated components that immediately follow the
    ``r2d2-data-full`` segment are joined back together with ``/``;
    everything before and after them is discarded.

    Args:
        full_path: ``/``-separated calibration key containing an
            ``r2d2-data-full`` segment.

    Returns:
        The relative path ``LAB/STATUS/DATE/TIMESTAMP``.

    Raises:
        ValueError: if ``full_path`` has no ``r2d2-data-full`` segment, or
            fewer than four components after it. The message includes the
            offending key so bad input records are easy to locate.
    """
    marker = "r2d2-data-full"
    parts = full_path.split("/")
    try:
        idx = parts.index(marker)
    except ValueError:
        raise ValueError(
            f"{marker!r} segment not found in calibration key: {full_path!r}"
        ) from None

    rel_parts = parts[idx + 1:idx + 5]
    if len(rel_parts) != 4:
        raise ValueError(
            f"expected LAB/STATUS/DATE/TIMESTAMP after {marker!r} "
            f"in calibration key: {full_path!r}"
        )
    return "/".join(rel_parts)

def normalize(rel_path: str) -> str:
    """Return *rel_path* with every ``:`` in its last component turned into ``_``.

    Everything up to and including the final ``/`` is kept verbatim.
    Raises ValueError if *rel_path* contains no ``/`` at all.
    """
    cut = rel_path.rindex("/") + 1  # position just past the last separator
    return rel_path[:cut] + rel_path[cut:].replace(":", "_")

def denormalize(rel_path: str) -> str:
    """Return *rel_path* with every ``_`` in its last component turned into ``:``.

    Everything up to and including the final ``/`` is kept verbatim.
    Note this is NOT an exact inverse of ``normalize``: underscores that were
    already present before normalization are converted to colons as well.
    Raises ValueError if *rel_path* contains no ``/`` at all.
    """
    cut = rel_path.rindex("/") + 1  # position just past the last separator
    return rel_path[:cut] + rel_path[cut:].replace("_", ":")

def build_episode_lookup(full_episode_map):
    """Index the episode map in both directions.

    Args:
        full_episode_map: mapping of episode id -> dict whose ``"r2d2"`` and
            ``"gresearch"`` entries hold a relative path; missing or falsy
            entries are ignored.

    Returns:
        ``(rel2ep, ep2rel)`` where
        - ``rel2ep`` maps the ``normalize()``-d form of every path (from either
          source) to its episode id; if two episodes share a normalized path,
          the one processed last wins.
        - ``ep2rel`` maps each episode id to its original (un-normalized) path,
          taking the ``"r2d2"`` path when present and ``"gresearch"`` otherwise.
    """
    path_to_episode = {}
    episode_to_path = {}
    for episode_id, sources in full_episode_map.items():
        # r2d2 is visited first, so it claims episode_to_path before gresearch.
        for source_name in ("r2d2", "gresearch"):
            path = sources.get(source_name)
            if not path:  # None / empty string -> nothing to index
                continue
            path_to_episode[normalize(path)] = episode_id
            already_set = episode_id in episode_to_path
            if source_name == "r2d2" or not already_set:
                episode_to_path[episode_id] = path
    return path_to_episode, episode_to_path

def get_serial(v: Optional[dict], which: str) -> Optional[str]:
    """Return the serial of exterior camera *which* from the record *v*.

    Several key spellings are tried in a fixed order; the first key holding a
    truthy value wins, and that value is converted to ``str``.

    Args:
        v: Per-episode camera-serial record. ``None`` (e.g. a JSON ``null``)
            or an empty dict yields ``None``.
        which: ``"ext1"`` or ``"ext2"``.

    Returns:
        The serial as a string, or ``None`` if no candidate key has a truthy value.

    Raises:
        KeyError: if *which* is not ``"ext1"`` or ``"ext2"``.
    """
    candidates = {
        "ext1": ["ext1_cam_serial", "ext_1_cam_serial", "ext_cam_serial_1", "ext1"],
        "ext2": ["ext2_cam_serial", "ext_2_cam_serial", "ext_cam_serial_2", "ext2"],
    }[which]
    if not v:
        return None
    for k in candidates:
        value = v.get(k)
        if value:
            return str(value)
    return None

# ---------- Load ----------
def _load_json(path):
    """Parse the JSON file at *path*, decoding it explicitly as UTF-8.

    An explicit encoding avoids depending on the platform's locale default.
    """
    with open(path, "r", encoding="utf-8") as f:
        return json.load(f)


pnp_data = _load_json(pnp_calib_file)
cam_serials = _load_json(cam_serials_file)
full_episode_map = _load_json(full_episode_map_file)

rel2ep, ep2rel = build_episode_lookup(full_episode_map)

# ---------- Convert ----------
result = {}
skipped_bvl = skipped_no_ep = 0

# (calibration camera name, get_serial() selector) pairs, processed in order.
_EXTERIOR_CAMS = (
    ("exterior_image_1_left", "ext1"),
    ("exterior_image_2_left", "ext2"),
)

for full_path, cams in pnp_data.items():
    rel_path = extract_rel_path(full_path)
    lab = rel_path.split("/", 1)[0]
    if lab == "BVL":  # ignore BVL completely
        skipped_bvl += 1
        continue

    # rel2ep keys are always built with normalize(), so the normalized form is
    # the only lookup that can match.
    ep_id = rel2ep.get(normalize(rel_path))
    if not ep_id:
        skipped_no_ep += 1
        continue

    # Always include the episode, using the path preferred by ep2rel.
    entry = {"relative_path": ep2rel[ep_id]}

    # Try to attach extrinsics, keyed by camera serial. `or {}` guards against
    # JSON nulls as well as missing keys.
    serials = cam_serials.get(ep_id) or {}
    cams = cams or {}
    for cam_name, which in _EXTERIOR_CAMS:
        extrinsics = (cams.get(cam_name) or {}).get("extrinsics")
        serial = get_serial(serials, which)
        if extrinsics and serial:
            # NOTE(review): if ext1 and ext2 report the same serial, ext2 wins.
            entry[serial] = extrinsics

    # Store in result even if no serials matched.
    # NOTE(review): a later calibration key resolving to the same episode
    # overwrites an earlier one — confirm duplicates cannot occur.
    result[ep_id] = entry

# ---------- Save ----------
with open(output_file, "w") as out_fh:
    json.dump(result, out_fh, indent=4)

# One summary line per counter, emitted as a single write.
print("\n".join([
    f"✅ Converted JSON written to {output_file}",
    f"   Skipped BVL: {skipped_bvl}",
    f"   Skipped no episode mapping: {skipped_no_ep}",
    f"   Final entries: {len(result)}",
]))