# GBPhone / dodemo.py
# Uploaded by mhuckvale, commit e321f93 ("Upload 4 files").
# Build the processor (phone tokenizer + raw-waveform feature extractor) and
# load the fine-tuned CTC checkpoint from the ./GBPhone directory.
import torch
from transformers import (
    Wav2Vec2CTCTokenizer,
    Wav2Vec2FeatureExtractor,
    Wav2Vec2ForCTC,
    Wav2Vec2Processor,
)

# Tokenizer over the phone vocabulary stored in vocab.json; "|" is the word
# delimiter token, [UNK]/[PAD] are the unknown and padding tokens.
tokenizer = Wav2Vec2CTCTokenizer(
    "./GBPhone/vocab.json",
    unk_token="[UNK]",
    pad_token="[PAD]",
    word_delimiter_token="|",
    sep_token=" ",
)

# Single-channel (feature_size=1) waveform input at 16 kHz; do_normalize
# applies per-input normalisation; no attention mask is returned.
feature_extractor = Wav2Vec2FeatureExtractor(
    feature_size=1,
    sampling_rate=16000,
    padding_value=0.0,
    do_normalize=True,
    return_attention_mask=False,
)

processor = Wav2Vec2Processor(feature_extractor=feature_extractor, tokenizer=tokenizer)

# Report whether a GPU is visible before the (slow) model load.
print("CUDA available:", torch.cuda.is_available())

path = "./GBPhone/checkpoint-2300"
finetuned_model = Wav2Vec2ForCTC.from_pretrained(path)
def map_to_result(batch, device=None):
    """Run the fine-tuned CTC model on one utterance and greedily decode it.

    Args:
        batch: dict with "speech" (the waveform passed to ``processor``) and
            "sampling_rate" (its sample rate in Hz). The dict is updated in
            place and also returned.
        device: torch device to run inference on. When None, uses "cuda" if
            ``torch.cuda.is_available()`` and otherwise falls back to "cpu",
            so the demo also runs on machines without a GPU.

    Returns:
        The same dict with three added keys:
            "logits":   the model's output logits, left on ``device``
            "pred_ids": argmax of the logits over the last dimension
            "pred_str": ``processor.batch_decode`` of "pred_ids"
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    # nn.Module.to() moves the parameters in place, so later calls reuse it.
    finetuned_model.to(device)
    input_values = processor(
        batch["speech"],
        sampling_rate=batch["sampling_rate"],
        return_tensors="pt",
    ).input_values.to(device)
    # Inference only: disable autograd tracking.
    with torch.no_grad():
        logits = finetuned_model(input_values).logits
    batch["logits"] = logits
    batch["pred_ids"] = torch.argmax(logits, dim=-1)
    batch["pred_str"] = processor.batch_decode(
        batch["pred_ids"],
        skip_special_tokens=True,
        spaces_between_special_tokens=True,
    )
    return batch
# Load every .wav file in the current directory, run the phone recogniser on
# it, and save the per-frame logits as a CSV file next to it.
import glob
import os

import librosa as lb
import numpy as np

# Seconds between successive model output frames, as assumed by this script.
# NOTE(review): wav2vec2-style encoders typically emit one frame per 20 ms of
# 16 kHz audio — confirm for this checkpoint.
FRAME_STEP_S = 0.02

# ARPAbet-style phone tokens (tokenizer vocabulary) -> SAMPA symbols used in
# the CSV header. "sil" becomes "/", [UNK] becomes "unk" and [PAD] becomes
# "blk" (presumably the CTC blank).
arpa2sampa = {
    "aa": "A:", "ae": "{", "ah": "V", "ao": "O:", "aw": "aU", "ax": "@",
    "ay": "aI", "b": "b", "ch": "tS", "d": "d", "dh": "D", "ea": "e@",
    "eh": "e", "er": "3:", "ey": "eI", "f": "f", "g": "g", "hh": "h",
    "ia": "I@", "ih": "I", "iy": "i:", "jh": "dZ", "k": "k", "l": "l",
    "m": "m", "n": "n", "ng": "N", "oh": "Q", "ow": "@U", "oy": "OI",
    "p": "p", "r": "r", "s": "s", "sh": "S", "sil": "/", "t": "t",
    "th": "T", "ua": "U@", "uh": "U", "uw": "u:", "v": "v", "w": "w",
    "y": "j", "z": "z", "zh": "Z",
    "[UNK]": "unk", "[PAD]": "blk",
}


def _sampa_label(token_id):
    """Return the CSV header label for model output class *token_id*.

    Tokens missing from ``arpa2sampa`` (for example any extra special tokens
    the tokenizer registers beyond the phone set) keep their raw token text
    instead of raising KeyError.
    """
    token = tokenizer.convert_ids_to_tokens(token_id)
    return arpa2sampa.get(token, str(token))


# sorted(): glob returns files in arbitrary order; make processing deterministic.
for fname in sorted(glob.glob("*.wav")):
    # librosa resamples to 16 kHz on load, matching the feature extractor.
    speech_array, sampling_rate = lb.load(fname, sr=16000)
    print("Loaded %s with %.2fs at %gHz"
          % (fname, len(speech_array) / sampling_rate, sampling_rate))
    results = map_to_result({"speech": speech_array, "sampling_rate": sampling_rate})

    # Save logits as text: one row per output frame, first column is time.
    # A single utterance was processed, so take batch item 0.
    ltab = results["logits"].cpu().numpy()[0, :, :]
    n_frames, n_classes = ltab.shape
    times = FRAME_STEP_S * np.arange(n_frames)
    ltab2 = np.column_stack((times, ltab))
    # One label per logit column, so header and data always agree in width
    # (the original hard-coded 49 classes).
    header = "Time," + ",".join('"%s"' % _sampa_label(x) for x in range(n_classes))
    # splitext swaps only the final extension; str.replace would also rewrite
    # any earlier ".wav" substring (e.g. "x.wave.wav" -> "x.csve.csv").
    cname = os.path.splitext(fname)[0] + ".csv"
    np.savetxt(cname, ltab2, fmt="%.4f", delimiter=",", header=header, comments="")