Enhance segment extraction to support dataclass and dictionary formats in diarization output
This commit is contained in:
66
main.py
66
main.py
@@ -19,6 +19,7 @@ console = Console()
|
|||||||
from torch.serialization import add_safe_globals
|
from torch.serialization import add_safe_globals
|
||||||
from torch.torch_version import TorchVersion
|
from torch.torch_version import TorchVersion
|
||||||
from pyannote.audio.core.task import Specifications, Problem, Resolution
|
from pyannote.audio.core.task import Specifications, Problem, Resolution
|
||||||
|
import dataclasses
|
||||||
|
|
||||||
# allowloading checkpoints that store these classes (PyTorch 2.6+ weights_only change)
|
# allowloading checkpoints that store these classes (PyTorch 2.6+ weights_only change)
|
||||||
add_safe_globals([TorchVersion, Specifications, Problem, Resolution])
|
add_safe_globals([TorchVersion, Specifications, Problem, Resolution])
|
||||||
@@ -117,30 +118,65 @@ with Progress(TextColumn("Diarization fut…"), BarColumn(), TimeElapsedColumn()
|
|||||||
speaker_segments = []
|
speaker_segments = []
|
||||||
speakers = set()
|
speakers = set()
|
||||||
|
|
||||||
|
def _yield_from_annotation(annotation_obj):
|
||||||
|
if annotation_obj is None:
|
||||||
|
return
|
||||||
|
if hasattr(annotation_obj, "itertracks"):
|
||||||
|
for segment, _, speaker in annotation_obj.itertracks(yield_label=True):
|
||||||
|
yield speaker, float(segment.start), float(segment.end)
|
||||||
|
|
||||||
|
def _yield_from_track_dicts(track_list):
|
||||||
|
if not track_list:
|
||||||
|
return
|
||||||
|
for item in track_list:
|
||||||
|
if not isinstance(item, dict):
|
||||||
|
continue
|
||||||
|
speaker = item.get("speaker") or item.get("label") or "UNKNOWN"
|
||||||
|
start = float(item.get("start", 0.0))
|
||||||
|
end = float(item.get("end", start))
|
||||||
|
yield speaker, start, end
|
||||||
|
|
||||||
def extract_segments(diar_result):
|
def extract_segments(diar_result):
|
||||||
# pyannote 3.x returns Annotation with itertracks; 4.x returns DiarizeOutput
|
# pyannote 3.x returns Annotation with itertracks; 4.x returns DiarizeOutput/dataclass
|
||||||
if hasattr(diar_result, "itertracks"):
|
if hasattr(diar_result, "itertracks"):
|
||||||
for segment, _, speaker in diar_result.itertracks(yield_label=True):
|
for segment, _, speaker in diar_result.itertracks(yield_label=True):
|
||||||
yield speaker, float(segment.start), float(segment.end)
|
yield speaker, float(segment.start), float(segment.end)
|
||||||
return
|
return
|
||||||
|
|
||||||
annotation = getattr(diar_result, "annotation", None)
|
# dataclass (e.g., DiarizeOutput) -> inspect fields
|
||||||
if annotation is None and isinstance(diar_result, dict):
|
if dataclasses.is_dataclass(diar_result):
|
||||||
annotation = diar_result.get("annotation")
|
data_dict = dataclasses.asdict(diar_result)
|
||||||
if annotation is not None and hasattr(annotation, "itertracks"):
|
yield from extract_segments(data_dict)
|
||||||
for segment, _, speaker in annotation.itertracks(yield_label=True):
|
|
||||||
yield speaker, float(segment.start), float(segment.end)
|
|
||||||
return
|
return
|
||||||
|
|
||||||
tracks = getattr(diar_result, "tracks", None)
|
# dict-like
|
||||||
if tracks is None and isinstance(diar_result, dict):
|
if isinstance(diar_result, dict):
|
||||||
|
annotation = diar_result.get("annotation")
|
||||||
|
if annotation is not None:
|
||||||
|
yield from _yield_from_annotation(annotation)
|
||||||
|
return
|
||||||
tracks = diar_result.get("tracks")
|
tracks = diar_result.get("tracks")
|
||||||
if tracks:
|
yielded = False
|
||||||
for item in tracks:
|
for seg in _yield_from_track_dicts(tracks):
|
||||||
speaker = item.get("speaker") or item.get("label") or "UNKNOWN"
|
yielded = True
|
||||||
start = float(item.get("start", 0.0))
|
yield seg
|
||||||
end = float(item.get("end", start))
|
if yielded:
|
||||||
yield speaker, start, end
|
return
|
||||||
|
# try any nested value that is annotation-like
|
||||||
|
for val in diar_result.values():
|
||||||
|
if hasattr(val, "itertracks"):
|
||||||
|
yield from _yield_from_annotation(val)
|
||||||
|
return
|
||||||
|
# as last resort, try list/tuple of dicts
|
||||||
|
if isinstance(diar_result, (list, tuple)):
|
||||||
|
for seg in _yield_from_track_dicts(diar_result):
|
||||||
|
yield seg
|
||||||
|
return
|
||||||
|
|
||||||
|
# generic list/tuple
|
||||||
|
if isinstance(diar_result, (list, tuple)):
|
||||||
|
for seg in _yield_from_track_dicts(diar_result):
|
||||||
|
yield seg
|
||||||
return
|
return
|
||||||
|
|
||||||
console.print(f"❌ Ismeretlen diarization output: {type(diar_result)}")
|
console.print(f"❌ Ismeretlen diarization output: {type(diar_result)}")
|
||||||
|
|||||||
Reference in New Issue
Block a user