creative_studio/backend/app/utils/script_splitter.py

325 lines
10 KiB
Python
Raw Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

"""
剧本切分器
支持混合切分策略:
1. 优先按场景标记切分 (第X场、Scene X、【场景X】等)
2. 单场超过阈值时按固定长度切分 (5000 字符, 保留 500 字符重叠)
"""
import re
from dataclasses import dataclass, field
from typing import Any, Dict, List

from app.utils.logger import get_logger
# Module-level logger created via the project's get_logger helper
# (imported from app.utils.logger); ScriptSplitter logs progress through it.
logger = get_logger(__name__)
@dataclass
class ScriptSegment:
    """One piece of a split script.

    Field order is unchanged, so positional construction keeps working:
    index, content, start_pos, end_pos, scene_marker, metadata.
    """
    index: int  # 0-based position of the segment in the split result
    content: str  # text of the segment
    start_pos: int  # start offset in the original text
    end_pos: int  # end offset in the original text (exclusive, slice-style)
    scene_marker: str = ""  # matched scene marker, or "" when there is none
    # Extra metadata. Each instance gets its own fresh dict via default_factory.
    metadata: Dict[str, Any] = field(default_factory=dict)

    def __post_init__(self):
        # Callers may still pass metadata=None explicitly; normalise it to {}
        # so downstream code can always call .get() on it.
        if self.metadata is None:
            self.metadata = {}
class ScriptSplitter:
    """Script splitter using a hybrid strategy.

    1. Split on scene markers first (第X场, Scene X, 【场景X】, ...). Non-blank
       text before the first marker is kept as its own unmarked segment.
    2. A segment longer than ``max_segment_length`` is split again into
       fixed-length chunks, each followed by ``overlap_length`` characters of
       the next chunk as context.
    Content without any scene marker is split by length only.
    """

    # Regex patterns recognised as scene markers. They are not anchored to the
    # start of a line, so a match can also occur inside running text (e.g. a
    # bare "3场"). Matches from different patterns can overlap; see
    # _find_scene_markers for how overlaps are resolved.
    SCENE_PATTERNS = [
        r'第[一二三四五六七八九十百千0-9]+[集场幕]',  # 第X场/集/幕 (Chinese or Arabic numerals)
        r'[第第]?[0-9]+[集场幕]',  # 第1场 or bare 1场/集/幕 (repeated 第 in the class is redundant)
        r'Scene\s*\d+',  # Scene 1, Scene 2, ...
        r'【场景[^\n]*】',  # 【场景xxx】
        r'【[场景场][^\n]*】',  # 【场xxx】 / 【景xxx】
        r'\[场景[^\n]*\]',  # [场景xxx]
        r'场景[:\uff1a][^\n]*',  # 场景:xxx with an ASCII or full-width (U+FF1A) colon
        r'[0-9]+\.[0-9]+\s*[场景场外]',  # 1.1 场 / 1.1 景 / 1.1 外
    ]

    # Splitting thresholds (characters)
    MAX_SEGMENT_LENGTH = 5000  # maximum length of one segment
    OVERLAP_LENGTH = 500  # context appended to each length-split chunk
    CHUNK_THRESHOLD = 8000  # content longer than this gets split

    def __init__(
        self,
        max_segment_length: int = MAX_SEGMENT_LENGTH,
        overlap_length: int = OVERLAP_LENGTH,
        chunk_threshold: int = CHUNK_THRESHOLD
    ):
        """Configure the splitter.

        Args:
            max_segment_length: Maximum length of a single segment; must be > 0.
            overlap_length: Number of following characters appended to each
                length-split chunk so context across the cut is preserved.
            chunk_threshold: Content longer than this is split; anything at or
                below it is returned as a single segment.

        Raises:
            ValueError: If ``max_segment_length`` is not positive (the
                length-splitting step would otherwise never advance).
        """
        if max_segment_length <= 0:
            raise ValueError(
                f"max_segment_length must be positive, got {max_segment_length}"
            )
        self.max_segment_length = max_segment_length
        self.overlap_length = overlap_length
        self.chunk_threshold = chunk_threshold

    def split(self, content: str) -> List["ScriptSegment"]:
        """Split script content into segments.

        Content that is empty, ``None``, or no longer than ``chunk_threshold``
        is returned as a single segment. Otherwise it is split at scene
        markers. Any scene segment still longer than ``max_segment_length`` is
        then split again by length.

        Args:
            content: Script text.

        Returns:
            Segments numbered consecutively from 0.
        """
        if not content or len(content) <= self.chunk_threshold:
            # Short content needs no splitting. Treat None as "" so len() below
            # does not raise.
            text = content or ""
            return [ScriptSegment(
                index=0,
                content=text,
                start_pos=0,
                end_pos=len(text),
                metadata={"split_method": "none", "reason": "content_too_short"}
            )]

        segments = self._split_by_scenes(content)

        # Only scene-derived segments can be oversized; length chunks are
        # allowed to exceed max_segment_length by their overlap.
        if any(self._needs_resplit(s) for s in segments):
            segments = self._split_oversized_segments(segments)

        logger.info(f"剧本切分完成:共 {len(segments)} 个片段")
        return segments

    def _find_scene_markers(self, content: str) -> List[Dict[str, Any]]:
        """Locate scene markers in ``content``, ordered by position.

        Every pattern in ``SCENE_PATTERNS`` is searched. The patterns overlap:
        ``1.1场`` also contains ``1场``, and ``【场景：客厅】`` also contains
        ``场景：客厅】``. A match that starts inside an already accepted match
        is therefore discarded. At equal start positions the longest match
        wins, and among equal-length matches the earlier pattern wins.

        Args:
            content: Script text.

        Returns:
            A list of dicts with keys ``pos``, ``end``, ``marker`` and ``pattern``.
        """
        candidates = []
        for pattern in self.SCENE_PATTERNS:
            regex = re.compile(pattern, re.MULTILINE)
            for match in regex.finditer(content):
                candidates.append({
                    'pos': match.start(),
                    'end': match.end(),
                    'marker': match.group(0),
                    'pattern': regex.pattern,
                })

        # Earliest first; at the same start, prefer the longest match. The sort
        # is stable, so ties keep pattern order.
        candidates.sort(key=lambda c: (c['pos'], c['pos'] - c['end']))

        markers = []
        last_end = -1
        for cand in candidates:
            if cand['pos'] < last_end:
                continue  # overlaps a marker that was already kept
            markers.append(cand)
            last_end = cand['end']
        return markers

    def _split_by_scenes(self, content: str) -> List["ScriptSegment"]:
        """Split ``content`` at scene markers.

        Each segment runs from one marker to the next marker (or to the end of
        the text) and is whitespace-stripped. Non-blank text before the first
        marker becomes a leading segment with an empty ``scene_marker``
        instead of being discarded. Falls back to ``_split_by_length`` when no
        marker is found.

        Args:
            content: Script text.

        Returns:
            Segments numbered consecutively from 0.
        """
        markers = self._find_scene_markers(content)
        if not markers:
            # No scene markers: fall back to fixed-length splitting.
            return self._split_by_length(content)

        segments = []

        # Text before the first scene, e.g. a title or cast list.
        first_pos = markers[0]['pos']
        preamble = content[:first_pos]
        preamble_text = preamble.strip()
        if preamble_text:
            segments.append(ScriptSegment(
                index=0,
                content=preamble_text,
                # Offset where the stripped text starts, so that position
                # arithmetic on sub-splits stays correct.
                start_pos=len(preamble) - len(preamble.lstrip()),
                end_pos=first_pos,
                scene_marker="",
                metadata={
                    "split_method": "scene",
                    "scene_marker": "",
                    "preamble": True,
                }
            ))

        ends = [m['pos'] for m in markers[1:]] + [len(content)]
        for info, end_pos in zip(markers, ends):
            raw = content[info['pos']:end_pos]
            text = raw.strip()
            if not text:
                continue
            segments.append(ScriptSegment(
                # Number by output position so a preamble segment does not
                # collide with the first scene's index.
                index=len(segments),
                content=text,
                start_pos=info['pos'] + (len(raw) - len(raw.lstrip())),
                end_pos=end_pos,
                scene_marker=info['marker'],
                metadata={
                    "split_method": "scene",
                    "scene_marker": info['marker']
                }
            ))
        return segments

    def _split_by_length(self, content: str) -> List["ScriptSegment"]:
        """Split ``content`` into fixed-length chunks with trailing overlap.

        Chunks start every ``max_segment_length`` characters. Every chunk
        except the last also carries the next ``overlap_length`` characters,
        so its content can be up to ``max_segment_length + overlap_length``
        long. ``start_pos``/``end_pos`` describe the chunk without the overlap.

        Args:
            content: Script text.

        Returns:
            Segments numbered consecutively from 0.
        """
        segments = []
        total = len(content)
        step = self.max_segment_length
        for index, pos in enumerate(range(0, total, step)):
            end_pos = min(pos + step, total)
            has_overlap = end_pos < total
            chunk = content[pos:end_pos]
            if has_overlap:
                # Not the last chunk: append the start of the next one as context.
                chunk += content[end_pos:end_pos + self.overlap_length]
            segments.append(ScriptSegment(
                index=index,
                content=chunk,
                start_pos=pos,
                end_pos=end_pos,
                metadata={
                    "split_method": "length",
                    "has_overlap": has_overlap
                }
            ))
        return segments

    def _needs_resplit(self, segment: "ScriptSegment") -> bool:
        """Return True if ``segment`` is too long and should be length-split.

        Segments produced by length splitting (``split_method`` "length" or
        "..._then_length") are intentionally longer than
        ``max_segment_length`` by their overlap. They are never re-split;
        re-splitting them would emit duplicated overlap-sized tail fragments.
        """
        method = str(segment.metadata.get("split_method", ""))
        if method.endswith("length"):
            return False
        return len(segment.content) > self.max_segment_length

    def _split_oversized_segments(
        self,
        segments: List["ScriptSegment"]
    ) -> List["ScriptSegment"]:
        """Length-split every segment that ``_needs_resplit`` flags.

        Sub-segments inherit the parent's metadata. Their ``split_method``
        becomes ``"<parent method>_then_length"`` and their positions are
        translated back into the original text. All output segments are
        renumbered 0..n-1; segments kept as they are have their ``index``
        updated in place.

        Args:
            segments: Segments to post-process.

        Returns:
            A new list of segments.
        """
        result = []
        for segment in segments:
            if not self._needs_resplit(segment):
                # Renumber so indices stay unique after earlier segments expanded.
                segment.index = len(result)
                result.append(segment)
                continue

            logger.info(f"片段 {segment.index} 过长 ({len(segment.content)} 字符),进行二次切分")
            parent_marker = segment.scene_marker
            parent_method = segment.metadata.get('split_method', '')
            for i, sub in enumerate(self._split_by_length(segment.content)):
                # The first piece keeps the marker; later pieces are tagged
                # "(续)" (continued). An empty marker (preamble) stays empty.
                if i == 0 or not parent_marker:
                    marker = parent_marker
                else:
                    marker = f"{parent_marker}(续)"
                result.append(ScriptSegment(
                    index=len(result),
                    content=sub.content,
                    start_pos=segment.start_pos + sub.start_pos,
                    end_pos=segment.start_pos + sub.end_pos,
                    scene_marker=marker,
                    metadata={
                        **segment.metadata,
                        "split_method": f"{parent_method}_then_length",
                        "parent_scene": parent_marker,
                        "sub_index": i
                    }
                ))
        return result

    def get_split_summary(self, segments: List["ScriptSegment"]) -> Dict[str, Any]:
        """Summarise a list of segments.

        Args:
            segments: Segments, typically the result of ``split``.

        Returns:
            A dict with total_segments, total_length, average_length (integer
            division), min_length, max_length, split_methods (a count per
            ``metadata["split_method"]``, "unknown" when missing) and
            has_scenes. Empty input gives the same keys, with zeros for the
            numeric fields, ``{}`` and ``False``.
        """
        if not segments:
            return {
                "total_segments": 0,
                "total_length": 0,
                "average_length": 0,
                "min_length": 0,
                "max_length": 0,
                "split_methods": {},
                "has_scenes": False
            }

        lengths = [len(s.content) for s in segments]
        total_length = sum(lengths)

        split_methods: Dict[str, int] = {}
        for segment in segments:
            method = segment.metadata.get("split_method", "unknown")
            split_methods[method] = split_methods.get(method, 0) + 1

        return {
            "total_segments": len(segments),
            "total_length": total_length,
            "average_length": total_length // len(segments),
            "min_length": min(lengths),
            "max_length": max(lengths),
            "split_methods": split_methods,
            "has_scenes": any(s.scene_marker for s in segments)
        }
def split_script(
    content: str,
    max_segment_length: int = ScriptSplitter.MAX_SEGMENT_LENGTH,
    overlap_length: int = ScriptSplitter.OVERLAP_LENGTH,
    chunk_threshold: int = ScriptSplitter.CHUNK_THRESHOLD
) -> Dict[str, Any]:
    """Split script text and return the result as plain dicts and lists.

    This is a convenience wrapper. It builds a ``ScriptSplitter`` with the
    given settings, splits ``content``, and summarises the segments.

    Args:
        content: Script text.
        max_segment_length: Maximum length of one segment.
        overlap_length: Overlap appended to length-split chunks.
        chunk_threshold: Content longer than this is split.

    Returns:
        ``{"segments": [...], "summary": {...}}``. Each segment entry has the
        keys index, content, start_pos, end_pos, scene_marker, length and
        metadata. The summary comes from ``ScriptSplitter.get_split_summary``.
    """
    splitter = ScriptSplitter(
        max_segment_length=max_segment_length,
        overlap_length=overlap_length,
        chunk_threshold=chunk_threshold,
    )
    pieces = splitter.split(content)
    summary = splitter.get_split_summary(pieces)

    def to_plain_dict(seg):
        # Flatten one segment into serialisable values; key order matches
        # the historical output.
        return {
            "index": seg.index,
            "content": seg.content,
            "start_pos": seg.start_pos,
            "end_pos": seg.end_pos,
            "scene_marker": seg.scene_marker,
            "length": len(seg.content),
            "metadata": seg.metadata,
        }

    return {"segments": list(map(to_plain_dict, pieces)), "summary": summary}