torch.renorm operation in NumPy
2023-07-24
I have written a NumPy implementation of the torch.renorm
operation.
import numpy as np
def renorm(x, p, dim, maxnorm):
    """Rescale each slice of ``x`` along ``dim`` so its p-norm is at most ``maxnorm``.

    NumPy counterpart of ``torch.renorm``. ``x`` is split into the sub-arrays
    obtained by fixing one index of axis ``dim``. The p-norm of each sub-array
    is taken over *all* of its elements (a vector norm, not a matrix norm).
    Any slice whose norm exceeds ``maxnorm`` is multiplied by
    ``maxnorm / norm``. All other slices are left unchanged.

    Args:
        x: Array-like input with at least one dimension.
        p: Order of the vector norm, passed as ``ord`` to ``np.linalg.norm``.
            Must be > 0. ``np.inf`` gives the max-abs norm.
        dim: Axis along which slices are taken. Negative values count from
            the end.
        maxnorm: Maximum allowed norm of each slice. Must be >= 0.

    Returns:
        A tuple ``(renormed, factors)``. ``renormed`` has the shape of ``x``.
        ``factors`` is a 1-D float array of length ``x.shape[dim]`` holding
        the scale applied to each slice (1.0 for untouched slices).

    Raises:
        ValueError: If ``p <= 0`` or ``maxnorm < 0``.
        numpy AxisError: If ``dim`` is out of range for ``x`` (raised by
            ``np.moveaxis``).

    Note:
        PyTorch's kernel divides by ``norm + 1e-7``. This implementation uses
        the exact ratio, so results may differ in the last few digits.
    """
    if p <= 0:
        raise ValueError(f"renorm: p must be positive, got {p}")
    if maxnorm < 0:
        raise ValueError(f"renorm: maxnorm must be non-negative, got {maxnorm}")

    x = np.asarray(x)
    # Bring the slicing axis to the front and flatten every slice into one
    # row. np.moveaxis also validates ``dim`` and accepts negative values.
    # Each row's vector norm is then the norm over all elements of that slice.
    # The explicit row length (rather than -1) keeps this valid when n == 0.
    moved = np.moveaxis(x, dim, 0)
    n = moved.shape[0]
    rows = moved.reshape(n, int(np.prod(moved.shape[1:])))
    norms = np.linalg.norm(rows, ord=p, axis=1)

    # Scale only the slices whose norm is too large. Dividing only where
    # norm > maxnorm also avoids 0/0 warnings for all-zero slices.
    factors = np.ones(n)
    over = norms > maxnorm
    factors[over] = maxnorm / norms[over]

    # Broadcast the per-slice factors along ``dim`` only (not always axis 0).
    shape = [1] * x.ndim
    shape[dim] = n
    return x * factors.reshape(shape), factors