Enhance segment extraction to support dataclass and dictionary formats in diarization output

This commit is contained in:
2025-11-30 21:54:02 +01:00
parent 4c1dd4f15e
commit f985e0e029

66
main.py
View File

@@ -19,6 +19,7 @@ console = Console()
from torch.serialization import add_safe_globals
from torch.torch_version import TorchVersion
from pyannote.audio.core.task import Specifications, Problem, Resolution
import dataclasses
# allowloading checkpoints that store these classes (PyTorch 2.6+ weights_only change)
add_safe_globals([TorchVersion, Specifications, Problem, Resolution])
@@ -117,30 +118,65 @@ with Progress(TextColumn("Diarization fut…"), BarColumn(), TimeElapsedColumn()
speaker_segments = []
speakers = set()
def _yield_from_annotation(annotation_obj):
if annotation_obj is None:
return
if hasattr(annotation_obj, "itertracks"):
for segment, _, speaker in annotation_obj.itertracks(yield_label=True):
yield speaker, float(segment.start), float(segment.end)
def _yield_from_track_dicts(track_list):
if not track_list:
return
for item in track_list:
if not isinstance(item, dict):
continue
speaker = item.get("speaker") or item.get("label") or "UNKNOWN"
start = float(item.get("start", 0.0))
end = float(item.get("end", start))
yield speaker, start, end
def extract_segments(diar_result):
# pyannote 3.x returns Annotation with itertracks; 4.x returns DiarizeOutput
# pyannote 3.x returns Annotation with itertracks; 4.x returns DiarizeOutput/dataclass
if hasattr(diar_result, "itertracks"):
for segment, _, speaker in diar_result.itertracks(yield_label=True):
yield speaker, float(segment.start), float(segment.end)
return
annotation = getattr(diar_result, "annotation", None)
if annotation is None and isinstance(diar_result, dict):
annotation = diar_result.get("annotation")
if annotation is not None and hasattr(annotation, "itertracks"):
for segment, _, speaker in annotation.itertracks(yield_label=True):
yield speaker, float(segment.start), float(segment.end)
# dataclass (e.g., DiarizeOutput) -> inspect fields
if dataclasses.is_dataclass(diar_result):
data_dict = dataclasses.asdict(diar_result)
yield from extract_segments(data_dict)
return
tracks = getattr(diar_result, "tracks", None)
if tracks is None and isinstance(diar_result, dict):
# dict-like
if isinstance(diar_result, dict):
annotation = diar_result.get("annotation")
if annotation is not None:
yield from _yield_from_annotation(annotation)
return
tracks = diar_result.get("tracks")
if tracks:
for item in tracks:
speaker = item.get("speaker") or item.get("label") or "UNKNOWN"
start = float(item.get("start", 0.0))
end = float(item.get("end", start))
yield speaker, start, end
yielded = False
for seg in _yield_from_track_dicts(tracks):
yielded = True
yield seg
if yielded:
return
# try any nested value that is annotation-like
for val in diar_result.values():
if hasattr(val, "itertracks"):
yield from _yield_from_annotation(val)
return
# as last resort, try list/tuple of dicts
if isinstance(diar_result, (list, tuple)):
for seg in _yield_from_track_dicts(diar_result):
yield seg
return
# generic list/tuple
if isinstance(diar_result, (list, tuple)):
for seg in _yield_from_track_dicts(diar_result):
yield seg
return
console.print(f"❌ Ismeretlen diarization output: {type(diar_result)}")