Enhance segment extraction to support dataclass and dictionary formats in diarization output
This commit is contained in:
66
main.py
66
main.py
@@ -19,6 +19,7 @@ console = Console()
|
||||
from torch.serialization import add_safe_globals
|
||||
from torch.torch_version import TorchVersion
|
||||
from pyannote.audio.core.task import Specifications, Problem, Resolution
|
||||
import dataclasses
|
||||
|
||||
# allow loading checkpoints that store these classes (PyTorch 2.6+ weights_only change)
|
||||
add_safe_globals([TorchVersion, Specifications, Problem, Resolution])
|
||||
@@ -117,30 +118,65 @@ with Progress(TextColumn("Diarization fut…"), BarColumn(), TimeElapsedColumn()
|
||||
speaker_segments = []
|
||||
speakers = set()
|
||||
|
||||
def _yield_from_annotation(annotation_obj):
|
||||
if annotation_obj is None:
|
||||
return
|
||||
if hasattr(annotation_obj, "itertracks"):
|
||||
for segment, _, speaker in annotation_obj.itertracks(yield_label=True):
|
||||
yield speaker, float(segment.start), float(segment.end)
|
||||
|
||||
def _yield_from_track_dicts(track_list):
|
||||
if not track_list:
|
||||
return
|
||||
for item in track_list:
|
||||
if not isinstance(item, dict):
|
||||
continue
|
||||
speaker = item.get("speaker") or item.get("label") or "UNKNOWN"
|
||||
start = float(item.get("start", 0.0))
|
||||
end = float(item.get("end", start))
|
||||
yield speaker, start, end
|
||||
|
||||
def extract_segments(diar_result):
    """Yield ``(speaker, start, end)`` tuples from a diarization result.

    Supported shapes, tried in this order:

    * an annotation-like object exposing ``itertracks``;
    * a dataclass instance, converted with ``dataclasses.asdict`` and then
      processed recursively as a dict;
    * a dict holding an annotation-like value under ``"annotation"``, track
      dicts under ``"tracks"``, or any value that exposes ``itertracks``;
    * a list/tuple of track dicts.

    Anything else (including a dict in which none of the above is found) is
    reported on the console and yields nothing.
    """
    # pyannote 3.x returns Annotation with itertracks; 4.x returns DiarizeOutput/dataclass
    if hasattr(diar_result, "itertracks"):
        for segment, _, speaker in diar_result.itertracks(yield_label=True):
            yield speaker, float(segment.start), float(segment.end)
        return

    # dataclass (e.g., DiarizeOutput) -> inspect fields.
    # is_dataclass() is also true for dataclass *classes*, on which asdict()
    # raises TypeError, so only instances take this branch.
    # NOTE: asdict() deep-copies the field values.
    if dataclasses.is_dataclass(diar_result) and not isinstance(diar_result, type):
        yield from extract_segments(dataclasses.asdict(diar_result))
        return

    # dict-like
    if isinstance(diar_result, dict):
        annotation = diar_result.get("annotation")
        # Only trust the "annotation" entry when it is actually iterable as
        # tracks; otherwise keep looking at the other keys.
        if hasattr(annotation, "itertracks"):
            yield from _yield_from_annotation(annotation)
            return

        yielded = False
        for seg in _yield_from_track_dicts(diar_result.get("tracks")):
            yielded = True
            yield seg
        if yielded:
            return

        # try any nested value that is annotation-like
        for val in diar_result.values():
            if hasattr(val, "itertracks"):
                yield from _yield_from_annotation(val)
                return

    # generic list/tuple of track dicts
    if isinstance(diar_result, (list, tuple)):
        yield from _yield_from_track_dicts(diar_result)
        return

    console.print(f"❌ Ismeretlen diarization output: {type(diar_result)}")
|
||||
|
||||
Reference in New Issue
Block a user