Source code for fedscale.cloud.execution.data_processor

import torch
from torch.nn.utils.rnn import pad_sequence

from fedscale.cloud.fllibs import *


def collate(examples):
    """Pad a batch of token-id tensors to a uniform length for language modeling."""
    # `tokenizer` is a module-level global provided by fedscale.cloud.fllibs.
    if tokenizer._pad_token is None:
        return (pad_sequence(examples, batch_first=True), None)
    return (pad_sequence(examples, batch_first=True, padding_value=tokenizer.pad_token_id), None)
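
The function above is meant to be passed as the `collate_fn` of a `torch.utils.data.DataLoader` over variable-length token-id tensors. A minimal usage sketch follows; it is not part of the module, the toy sample list is hypothetical, and it assumes the global `tokenizer` imported from fedscale.cloud.fllibs has already been initialized for an NLP task.

import torch
from torch.utils.data import DataLoader

# Toy samples: three 1-D LongTensors of token ids with different lengths.
samples = [torch.randint(0, 1000, (n,)) for n in (5, 8, 3)]

# `collate` pads every sequence in the batch to the longest one (length 8 here).
loader = DataLoader(samples, batch_size=3, collate_fn=collate)
padded_ids, _ = next(iter(loader))  # padded_ids has shape [3, 8]; the second element is None
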
def voice_collate_fn(batch):
    """Zero-pad a batch of speech spectrograms and flatten their transcripts."""
    def func(p):
        return p[0].size(1)

    # Sort samples by time length (longest first) and size the padded batch
    # after the longest spectrogram.
    batch = sorted(batch, key=lambda sample: sample[0].size(1), reverse=True)
    longest_sample = max(batch, key=func)[0]
    freq_size = longest_sample.size(0)
    minibatch_size = len(batch)
    max_seqlength = longest_sample.size(1)
    inputs = torch.zeros(minibatch_size, 1, freq_size, max_seqlength)
    input_percentages = torch.FloatTensor(minibatch_size)
    target_sizes = torch.IntTensor(minibatch_size)
    targets = []
    for x in range(minibatch_size):
        sample = batch[x]
        tensor = sample[0]
        target = sample[1]
        seq_length = tensor.size(1)
        # Copy each spectrogram into its left-aligned, zero-padded slot and
        # record what fraction of the padded length is real audio.
        inputs[x][0].narrow(1, 0, seq_length).copy_(tensor)
        input_percentages[x] = seq_length / float(max_seqlength)
        target_sizes[x] = len(target)
        targets.extend(target)
    targets = torch.IntTensor(targets)
    return (inputs, targets, input_percentages, target_sizes), None
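
Similarly, `voice_collate_fn` plugs into a DataLoader for speech-recognition batches, where each sample is a `(spectrogram, transcript)` pair. Below is a minimal sketch with made-up shapes (161 frequency bins, variable time steps) and made-up label ids; it is illustrative only and not part of the module.

import torch
from torch.utils.data import DataLoader

# Each sample: (spectrogram [freq_bins, time_steps], transcript as a list of label ids).
speech_samples = [
    (torch.randn(161, 120), [3, 7, 1]),
    (torch.randn(161, 90), [5, 2]),
]

loader = DataLoader(speech_samples, batch_size=2, collate_fn=voice_collate_fn)
(inputs, targets, input_percentages, target_sizes), _ = next(iter(loader))
# inputs: [2, 1, 161, 120] (zero-padded along time); targets: flat IntTensor of 5 label ids;
# input_percentages: per-sample fraction of the padded length that is real audio;
# target_sizes: [3, 2], the transcript length of each sample after length-descending sort.
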