Briswell Tech Blog

ブリスウェルのテックブログです

スポーツ動画のタグ付け

そろそろ年末が近づいてきました。気持ちの良い秋晴れ(もう冬ですかね)の空が広がっています。先日、数年ぶりに体育館で運動をしました。普段PCと向き合ってガチガチの身体がほぐれて良かったです。

今回は久しぶりにAI関連の記事です。CLIPモデルを利用して動画を解析してみます。

CLIPはOpenAIによって開発されたモデルで、画像とその説明(テキスト)の関係を検出します。このモデルは、インターネットから集めた大量の画像とテキストのペアで学習しています。特定のタスク用に追加の学習を必要とせず、多様なシーンで精度を出せるのが魅力ですね。

体育館での運動の合間の一コマです。謎の動きをしていますが、はたしてCLIPモデルを何をしているか理解できるでしょうか。

1. 検出する動き(テキスト)を日本語・英語で定義

{
  "投げる": "throw",
  "歩く": "walk",
  "走る": "run",
  "飛ぶ": "jump",
  "泳ぐ": "swim",
  "踊る": "dance",
  "歌う": "sing",
  "座る": "sit",
  "描く": "draw",
  "寝る": "sleep"
}

2. テキストの特徴量をpickleファイルへ保存

import torch
import clip
import pickle
import json

# CLIPモデルの初期化
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/32", device=device)

# 事前に準備した日本語と英語の辞書
with open('japanese_to_english_dict.json', 'r', encoding='utf-8') as file:
    japanese_to_english_dict = json.load(file)

# 英語に翻訳されたタグをCLIPモデル用にトークナイズ
translated_tags = list(japanese_to_english_dict.values())
text = clip.tokenize(translated_tags).to(device)

# テキストの特徴量を計算
with torch.no_grad():
    text_features = model.encode_text(text)

# pickleファイルとして保存
with open('text_features.pkl', 'wb') as f:
    pickle.dump(text_features, f)

print("テキスト特徴量を保存しました。")

3. 動画を読み込んで各フレームにタグ付け

import cv2
import torch
import clip
import pickle
from PIL import Image
from collections import Counter
import json
import os
import glob

# 指定されたフォルダ内の画像を削除する関数
def delete_images_in_folder(folder, file_extension="*.jpg"):
    files = glob.glob(os.path.join(folder, file_extension))
    for f in files:
        os.remove(f)

# フレームにテキストを描画する関数
def draw_text_on_frame(frame, text, position, font=cv2.FONT_HERSHEY_SIMPLEX, 
                       font_scale=0.7, font_color=(0, 255, 0), line_type=2):
    cv2.putText(frame, text, position, font, font_scale, font_color, line_type)

# バッチごとにフレームを処理する関数
def process_batch(frame_batch, start_frame_index, model, transform, text_features, 
                  japanese_tags, japanese_to_english_dict, all_tags_for_video, 
                  output_folder, fps):
    # バッチ内の各フレームをRGBに変換
    batch_rgb = [cv2.cvtColor(frame, cv2.COLOR_BGR2RGB) for frame in frame_batch]

    # 変換されたフレームをPyTorchテンソルに変換
    batch_transformed = torch.stack([transform(Image.fromarray(img)) for img in batch_rgb]).to(device)

    # CLIPモデルを使用して画像の特徴量を抽出
    with torch.no_grad():
        image_features = model.encode_image(batch_transformed)
        logits_per_image = (image_features @ text_features.T)
        probs = logits_per_image.softmax(dim=1)
        top_tag_indices_list = probs.topk(N).indices

    # 各フレームごとに最も関連性の高いタグを選択
    for i, top_tag_indices in enumerate(top_tag_indices_list):
        valid_indices = [idx for idx in top_tag_indices if probs[i][idx] > SIMILARITY_THRESHOLD]
        top_tags_for_frame = [(japanese_tags[idx], probs[i][idx].item()) for idx in valid_indices]

        # 日本語タグを英語に変換
        top_tags_for_frame_english = [(japanese_to_english_dict[tag], score) for tag, score in top_tags_for_frame]

        # 現在のフレームのインデックスと時間を計算
        current_frame_index = start_frame_index + i
        current_frame_time = current_frame_index / fps

        # 処理中のフレームとそのタグをコンソールに出力
        print(f"Frame {current_frame_index} (Time: {current_frame_time:.2f} seconds): {top_tags_for_frame_english}")

        # フレームにタグを描画して保存
        for j, (eng_tag, score) in enumerate(top_tags_for_frame_english):
            text = f"{eng_tag}: {score:.2f}"
            draw_text_on_frame(frame_batch[i], text, (10, 30 + j*30))

        frame_filename = f"{output_folder}/frame_{current_frame_index}.jpg"
        cv2.imwrite(frame_filename, frame_batch[i])

        # 抽出されたタグを全タグのリストに追加
        for tag, _ in top_tags_for_frame:
            all_tags_for_video.append(tag)

# メインスクリプトの開始
device = "cuda" if torch.cuda.is_available() else "cpu"
model, transform = clip.load("ViT-B/32", device=device)

# pickleファイルからテキスト特徴量を読み込み
with open('text_features.pkl', 'rb') as f:
    text_features = pickle.load(f).to(device)

# 動画ファイルを読み込み
cap = cv2.VideoCapture('sports-movie.mp4')
fps = int(cap.get(cv2.CAP_PROP_FPS))

# 日本語と英語の辞書を読み込み
with open('japanese_to_english_dict.json', 'r', encoding='utf-8') as file:
    japanese_to_english_dict = json.load(file)

# 日本語のタグリストを作成
japanese_tags = list(japanese_to_english_dict.keys())

# 処理された全フレームのタグを保存するリストを初期化
all_tags_for_video = []

# バッチサイズ、上位N個のタグを選択するための数、類似度のしきい値を設定
BATCH_SIZE = 16
N = 3  # 上位N個のタグを選択
SIMILARITY_THRESHOLD = 0.2  # 類似度のしきい値

# バッチ処理用のフレームリストを初期化
frame_batch = []

# 出力されるフレームを保存するフォルダの設定
output_folder = 'output_frames'
if not os.path.exists(output_folder):
    os.makedirs(output_folder)  # フォルダが存在しない場合は作成
else:
    delete_images_in_folder(output_folder)  # フォルダが存在する場合は中の画像を全て削除

# 動画の各フレームを処理
frame_count = 0
while cap.isOpened():
    ret, frame = cap.read()  # フレームを読み込み
    if not ret:
        break  # フレームがない場合は終了

    frame_batch.append(frame)  # バッチリストにフレームを追加
    # バッチサイズに達したら処理を実行
    if len(frame_batch) == BATCH_SIZE:
        process_batch(frame_batch, frame_count - len(frame_batch) + 1, model, transform, 
                      text_features, japanese_tags, japanese_to_english_dict, 
                      all_tags_for_video, output_folder, fps)
        frame_batch = []  # 処理後はバッチリストをリセット
    frame_count += 1

# 残りのフレームを処理
if frame_batch:
    process_batch(frame_batch, frame_count - len(frame_batch) + 1, model, transform, 
                  text_features, japanese_tags, japanese_to_english_dict, 
                  all_tags_for_video, output_folder, fps)

cap.release()  # 動画の読み込みを終了

# タグの出現回数を集計し、ファイルに出力
tag_counts = Counter(all_tags_for_video)
with open('output_tags.txt', 'w', encoding='utf-8') as f:
    for tag, count in tag_counts.most_common():
        f.write(f"{tag}: {count}\n")  # タグとその出現回数をファイルに書き込み

print("動画の処理が完了しました。")

4. 実行結果と分析

タグとその出現回数は以下となります。

投げる: 832
踊る: 479
走る: 238

いいですね。多くは「投げる」と判断しています。

「踊る」「走る」はどのようなポイントで判断されているのが気になるところです。いくつかピックアップしてみます。

① 走る(run) 86%

まあ確かにこの画像だけを見ると走っているように見えますね。

② 投げる(throw) 50% & 走る(run) 41%

腕の部分は投げている雰囲気を出しています。

③ 投げる(throw) 75%

投げてます!

④ 踊る(dance) 74%

珍妙なダンスですが... 投げても走ってもいないですね。

5. 最後に

動画を読み込んで解析する場合、各フレームの静止画像に対して解析することになるので、上記のようにポイントでは誤った判断をすることがあります。そのため、全体を通してどのタグが一番多く検出されたのかを見ることで最終的な判断とすることがよさそうです。