import dataclasses


def _shallow_fields(obj):
    """Return ``{field name: value}`` for a dataclass instance, else ``None``.

    ``dataclasses.asdict`` is deliberately avoided: it recurses and deep-copies
    every value, which is wasted work for large results and raises for values
    that cannot be copied. Dataclass *classes* are rejected as well, because
    ``dataclasses.is_dataclass`` is true for them but they carry no field values.
    """
    if isinstance(obj, type) or not dataclasses.is_dataclass(obj):
        return None
    return {field.name: getattr(obj, field.name) for field in dataclasses.fields(obj)}


def _yield_from_annotation(annotation_obj):
    """Yield ``(speaker, start, end)`` for each track of an annotation-like object.

    "Annotation-like" means the object exposes ``itertracks(yield_label=True)``
    producing ``(segment, track, label)`` triples whose segment has ``start``
    and ``end``. Any other value, including ``None``, yields nothing.
    """
    if not hasattr(annotation_obj, "itertracks"):
        return
    for segment, _, speaker in annotation_obj.itertracks(yield_label=True):
        yield speaker, float(segment.start), float(segment.end)


def _yield_from_track_dicts(track_list):
    """Yield ``(speaker, start, end)`` for each track record in ``track_list``.

    A record is a ``dict`` (or a dataclass instance, read through its fields)
    with the keys ``speaker`` or ``label``, ``start`` and ``end``. A missing
    speaker becomes ``"UNKNOWN"``, a missing start ``0.0`` and a missing end
    falls back to the start. Items of any other type are skipped; an empty or
    ``None`` ``track_list`` yields nothing.
    """
    if not track_list:
        return
    for item in track_list:
        record = _shallow_fields(item)
        if record is None:
            record = item
        if not isinstance(record, dict):
            continue
        speaker = record.get("speaker") or record.get("label") or "UNKNOWN"
        start = float(record.get("start", 0.0))
        end = float(record.get("end", start))
        yield speaker, start, end


def extract_segments(diar_result):
    """Yield ``(speaker, start, end)`` tuples from a diarization result.

    Supported shapes, tried in this order:

    1. an annotation-like object with ``itertracks`` (per the original comment,
       what pyannote 3.x returns);
    2. a dataclass instance (per the original comment, pyannote 4.x returns a
       ``DiarizeOutput``), inspected through its fields like a dict;
    3. a ``dict`` with an annotation-like ``"annotation"`` value, else a
       ``"tracks"`` list of track records, else any annotation-like value;
    4. a ``list`` or ``tuple`` of track records.

    If nothing matches, a message naming the unrecognised type is printed via
    the module-level ``console`` object of the host script, and nothing is
    yielded.
    """
    if hasattr(diar_result, "itertracks"):
        yield from _yield_from_annotation(diar_result)
        return

    # Read dataclass results through a shallow field map, but keep
    # ``diar_result`` itself so the fallback message reports the real type.
    mapping = _shallow_fields(diar_result)
    if mapping is None and isinstance(diar_result, dict):
        mapping = diar_result

    if mapping is not None:
        annotation = mapping.get("annotation")
        # Only commit to this path when the value really is annotation-like;
        # otherwise keep looking instead of returning an empty result.
        if hasattr(annotation, "itertracks"):
            yield from _yield_from_annotation(annotation)
            return

        yielded = False
        for seg in _yield_from_track_dicts(mapping.get("tracks")):
            yielded = True
            yield seg
        if yielded:
            return

        # Field names differ between library versions, so accept the first
        # annotation-like value regardless of its key.
        for val in mapping.values():
            if hasattr(val, "itertracks"):
                yield from _yield_from_annotation(val)
                return
    elif isinstance(diar_result, (list, tuple)):
        yield from _yield_from_track_dicts(diar_result)
        return

    console.print(f"❌ Ismeretlen diarization output: {type(diar_result)}")