From 743c1a0b12fa3b19fbaaca080750a27206df974f Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Sat, 27 Oct 2018 03:16:12 +0200 Subject: [PATCH 1/2] Cholesky, LU, QR, triangularSolve, setDiag, diagPart, broadcastTo, conj matMul now supports broadcasting --- src/kernels/backend.ts | 5 + src/kernels/backend_cpu.ts | 95 +++ src/kernels/backend_webgl.ts | 22 + src/kernels/webgl/band_part_gpu.ts | 53 ++ src/kernels/webgl/diag_part_gpu.ts | 52 ++ src/kernels/webgl/set_diag_gpu.ts | 58 ++ src/ops/array_ops.ts | 57 ++ src/ops/array_ops_test.ts | 60 +- src/ops/complex_ops.ts | 18 + src/ops/complex_ops_test.ts | 29 + src/ops/linalg_ops.ts | 1278 +++++++++++++++++++++++++--- src/ops/linalg_ops_test.ts | 675 ++++++++++++--- src/ops/linalg_util.ts | 54 ++ src/ops/matmul.ts | 7 +- src/test_util.ts | 45 +- 15 files changed, 2246 insertions(+), 262 deletions(-) create mode 100644 src/kernels/webgl/band_part_gpu.ts create mode 100644 src/kernels/webgl/diag_part_gpu.ts create mode 100644 src/kernels/webgl/set_diag_gpu.ts create mode 100644 src/ops/linalg_util.ts diff --git a/src/kernels/backend.ts b/src/kernels/backend.ts index f8a00c5b8e..4401ef0a4e 100644 --- a/src/kernels/backend.ts +++ b/src/kernels/backend.ts @@ -87,6 +87,11 @@ export interface KernelBackend extends TensorStorage, BackendTimer { /** Returns the highest precision for floats in bits (e.g. 16 or 32) */ floatPrecision(): number; + matrixSetDiag( a: T, d: Tensor ): T; + + matrixDiagPart( a: Tensor ): Tensor; + matrixBandPart( a: T, numLower: number, numUpper: number ): T; + batchMatMul( a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean): Tensor3D; diff --git a/src/kernels/backend_cpu.ts b/src/kernels/backend_cpu.ts index 1ac27ff79f..1cedcef39e 100644 --- a/src/kernels/backend_cpu.ts +++ b/src/kernels/backend_cpu.ts @@ -385,6 +385,101 @@ export class MathBackendCPU implements KernelBackend { T; } + matrixSetDiag( a: T, d: Tensor ): T + { + if( a.dtype != d.dtype ) throw new Error(`setDiag(): Both tensors must have the same dtype.`); + if( a.rank < 2 ) throw new Error(`setDiag(): a.rank=${a.rank} < 2`); + if( d.rank < 1 ) throw new Error(`setDiag(): d.rank=${d.rank} < 1`); + if( a.shape.some( d => d < 0 || d%1 !== 0 ) ) throw new Error(`setDiag(): Invalid input a.shape [${a.shape}].`); + if( d.shape.some( d => d < 0 || d%1 !== 0 ) ) throw new Error(`setDiag(): Invalid input d.shape [${d.shape}].`); + if( a.shape.length != d.shape.length+1 ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${a.shape}] [${d.shape}]`) + + for( let i=a.rank-2; i-- > 0; ) + if( a.shape[i] != d.shape[i] ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${a.shape}] [${d.shape}]`) + + if( d.shape[d.rank-1] != Math.min( ...a.shape.slice(-2) ) ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${a.shape}] [${d.shape}]`) + + const [M,N] = a.shape.slice(-2), + L = Math.min(M,N), + dtype = a.dtype; + + const wordsPerElem = dtype.startsWith('complex') ? 
2 : 1, + A_shape = a.shape, + A = a.dataSync().slice(); a=undefined; + const D = d.dataSync() ; d=undefined; + + for( let A_off=0, + D_off=0; D_off < D.length; D_off += L, + A_off += M*N ) + { + for( let i=0; i < M; i++ ) + for( let k=0; k < wordsPerElem; k++ ) + A[wordsPerElem*(A_off + N*i+i)+k] = D[wordsPerElem*(D_off + i)+k]; + } + + return Tensor.make(A_shape,{values: A},dtype); + } + + matrixDiagPart( a: Tensor ): Tensor + { + if( a.rank < 2 ) throw new Error('diagPart(): Input a.rank must be at least 2.'); + if( a.shape.some( d => d < 0 || + d%1 !== 0 ) ) throw new Error(`diagPart(): Invalid input shape [${a.shape}].`); + + const D_shape = a.shape.slice(0,-1), + [M,N] = a.shape.slice(-2), + L = Math.min(M,N), + dtype = a.dtype, + wordsPerElem = dtype.startsWith('complex') ? 2 : 1, + A = a.dataSync(); + + D_shape[D_shape.length-1] = Math.min( ...a.shape.slice(-2) ); + Object.freeze(D_shape); + + const D: TypedArray = new (A).constructor( D_shape.reduce( (a,b) => a*b ) ); a = undefined; + + for( let A_off=0, + D_off=0; D_off < D.length; D_off += L, + A_off += M*N ) + { + for( let i=0; i < L; i++ ) + for( let k=0; k < wordsPerElem; k++ ) + D[wordsPerElem*(D_off + i) + k] = A[wordsPerElem*(A_off + N*i+i) + k]; + } + + return Tensor.make(D_shape, {values: D}, dtype); + } + + matrixBandPart( a: T, numLower: number, numUpper: number ): T + { + util.assert( numLower%1 == 0, `Error in bandPart: numLower=${numLower} is no integer.`); + util.assert( numUpper%1 == 0, `Error in bandPart: numUpper=${numUpper} is no integer.`); + util.assert( numLower >= 0, `Error in bandPart: numLower=${numLower} is negative.`); + util.assert( numUpper >= 0, `Error in bandPart: numUpper=${numUpper} is negative.`); + util.assert( a.rank >= 2, `Error in bandPart: a.rank=${a.rank} must not be less than 2.`); + + const B_shape = Array.from(a.shape), + [M,N] = a.shape.slice(-2), + dtype = a.dtype, + wordsPerElem = dtype.startsWith('complex') ? 
2 : 1, + B = a.dataSync().slice(); a = undefined; + + Object.freeze(B_shape); + + for( let off=0; off < B.length; off += M*N ) + for( let i=0; i < M; i++ ) + for( let j=0; j < N; j++ ) + if( j-i > numUpper || + i-j > numLower ) + for( let k=0; k < wordsPerElem; k++ ) + B[wordsPerElem*(off + N*i+j) + k] = 0; + + return Tensor.make(B_shape,{values: B},dtype); + } + batchMatMul( a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean): Tensor3D { diff --git a/src/kernels/backend_webgl.ts b/src/kernels/backend_webgl.ts index 123c783727..c966871744 100644 --- a/src/kernels/backend_webgl.ts +++ b/src/kernels/backend_webgl.ts @@ -49,6 +49,9 @@ import * as binaryop_complex_gpu from './webgl/binaryop_complex_gpu'; import {BinaryOpComplexProgram} from './webgl/binaryop_complex_gpu'; import * as binaryop_gpu from './webgl/binaryop_gpu'; import {BinaryOpProgram} from './webgl/binaryop_gpu'; +import {SetDiagProgram} from './webgl/set_diag_gpu'; +import {BandPartProgram} from './webgl/band_part_gpu'; +import {DiagPartProgram} from './webgl/diag_part_gpu'; import {ClipProgram} from './webgl/clip_gpu'; import {ComplexAbsProgram} from './webgl/complex_abs_gpu'; import {ConcatProgram} from './webgl/concat_gpu'; @@ -593,6 +596,25 @@ export class MathBackendWebGL implements KernelBackend { return this.compileAndRun(program, [x]) as T; } + matrixSetDiag( a: T, d: Tensor ): T + { + if( a.dtype != d.dtype ) throw new Error(`setDiag(): Both tensors must have the same dtype.`); + const program = new SetDiagProgram(a.shape,d.shape); + return this.compileAndRun(program, [a,d]); + } + + matrixDiagPart( a: Tensor ): Tensor + { + const program = new DiagPartProgram(a.shape); + return this.compileAndRun(program, [a]); + } + + matrixBandPart( a: T, numLower: number, numUpper: number ): T + { + const program = new BandPartProgram(a.shape, numLower, numUpper); + return this.compileAndRun(program, [a]); + } + batchMatMul( a: Tensor3D, b: Tensor3D, transposeA: boolean, transposeB: boolean): Tensor3D { diff --git a/src/kernels/webgl/band_part_gpu.ts b/src/kernels/webgl/band_part_gpu.ts new file mode 100644 index 0000000000..b299566e83 --- /dev/null +++ b/src/kernels/webgl/band_part_gpu.ts @@ -0,0 +1,53 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {GPGPUProgram} from './gpgpu_math'; +import {getCoordsDataType} from './shader_compiler'; + +export class BandPartProgram implements GPGPUProgram { + variableNames = ['A']; + outputShape: number[]; + userCode: string; + rank: number; + + constructor( aShape: number[], numLower: number, numUpper: number ) { + const rank = aShape.length; + this.outputShape = Array.from(aShape); + this.rank = rank;; + + if( numLower%1 !== 0 ) throw new Error(`bandPart(): numLower=${numLower} is no integer.`); + if( numUpper%1 !== 0 ) throw new Error(`bandPart(): numUpper=${numUpper} is no integer.`); + if( numLower < 0 ) throw new Error(`bandPart(): numLower=${numLower} is negative.`); + if( numUpper < 0 ) throw new Error(`bandPart(): numUpper=${numUpper} is negative.`); + + if( rank < 2 ) throw new Error(`bandPart(): a.rank=${rank} must not be less than 2.`); + if( rank > 6 ) throw new Error(`bandPart(): a.rank=${rank} not yet supported.`); + + const dtype = getCoordsDataType(rank), + idx = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'].slice(0,rank), + [i,j] = idx.slice(-2); + + this.userCode = ` + void main() { + ${dtype} resRC = getOutputCoords(); + if( ${j}-${i} > ${numUpper} ) setOutput(0.0); + else if( ${i}-${j} > ${numLower} ) setOutput(0.0); + else setOutput( getA(${idx.join()}) ); + } + `; + } +} diff --git a/src/kernels/webgl/diag_part_gpu.ts b/src/kernels/webgl/diag_part_gpu.ts new file mode 100644 index 0000000000..c3df615955 --- /dev/null +++ b/src/kernels/webgl/diag_part_gpu.ts @@ -0,0 +1,52 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {GPGPUProgram} from './gpgpu_math'; +import {getCoordsDataType} from './shader_compiler'; + +export class DiagPartProgram implements GPGPUProgram { + variableNames = ['A']; + outputShape: number[]; + userCode: string; + rank: number; + + constructor( aShape: number[] ) { + if( aShape.some( d => d < 0 || d%1 !== 0 ) ) throw new Error(`diagPart(): Invalid input shape [${aShape}].`); + + const rank = aShape.length-1; + this.rank = rank; + + if( rank < 1 ) throw new Error('diagPart(): Input rank must be at least 2.'); + + this.outputShape = aShape.slice(0,-1); + this.outputShape[rank-1] = Math.min( ...aShape.slice(-2) ); + + if( rank > 5 ) throw Error(`diagPart(): a.rank=${rank+1} is not yet supported.`); + + const dtype = getCoordsDataType(rank), + idx = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'].slice(0,rank); + if( 1 === rank ) idx[0] = 'resRC'; + idx.push( idx[rank-1] ); + + this.userCode = ` + void main() { + ${dtype} resRC = getOutputCoords(); + setOutput( getA(${idx.join()}) ); + } + `; + } +} diff --git a/src/kernels/webgl/set_diag_gpu.ts b/src/kernels/webgl/set_diag_gpu.ts new file mode 100644 index 0000000000..bc12e447bc --- /dev/null +++ b/src/kernels/webgl/set_diag_gpu.ts @@ -0,0 +1,58 @@ +/** + * @license + * Copyright 2017 Google Inc. All Rights Reserved. + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ * ============================================================================= + */ + +import {GPGPUProgram} from './gpgpu_math'; +import {getCoordsDataType} from './shader_compiler'; + +export class SetDiagProgram implements GPGPUProgram { + variableNames = ['A','D']; + outputShape: number[]; + userCode: string; + rank: number; + + constructor( aShape: number[], dShape: number[] ) { + if( aShape.length < 2 ) throw new Error(`setDiag(): a.rank=${aShape.length} < 2`); + if( dShape.length < 1 ) throw new Error(`setDiag(): d.rank=${dShape.length} < 1`); + if( aShape.some( d => d < 0 || d%1 !== 0 ) ) throw new Error(`setDiag(): Invalid input a.shape [${aShape}].`); + if( dShape.some( d => d < 0 || d%1 !== 0 ) ) throw new Error(`setDiag(): Invalid input d.shape [${dShape}].`); + if( aShape.length != dShape.length+1 ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${aShape}] [${dShape}]`) + + for( let i=aShape.length-2; i-- > 0; ) + if( aShape[i] != dShape[i] ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${aShape}] [${dShape}]`) + + if( dShape[dShape.length-1] != Math.min( ...aShape.slice(-2) ) ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${aShape}] [${dShape}]`) + + this.rank = aShape.length; + this.outputShape = aShape; + + const dtype = getCoordsDataType(this.rank), + idx = ['resRC.x', 'resRC.y', 'resRC.z', 'resRC.w', 'resRC.u', 'resRC.v'].slice(0,this.rank), + [i,j] = idx.slice(-2); + + this.userCode = ` + void main() { + ${dtype} resRC = getOutputCoords(); + if( ${i} == ${j} ) setOutput( getD(${idx.slice(0,-1).join()}) ); + else setOutput( getA(${idx .join()}) ); + } + `; + } +} + diff --git a/src/ops/array_ops.ts b/src/ops/array_ops.ts index 622483b06d..952c762b54 100644 --- a/src/ops/array_ops.ts +++ b/src/ops/array_ops.ts @@ -103,6 +103,58 @@ function eye_( } } +/** Broadcast an array to a compatible shape NumPy-style. + * + * The tensor's shape is compared to the broadcast shape from end to beginning. + * Ones are prepended to the tensor's shape until is has the same length as + * the broadcast shape. If input.shape[i]==shape[i], they (i+1)-th axis is + * already broadcast-compatible. If input.shape[i]==1 and shape[i]==N, then + * the input tensor is tiled N times along that axis (using tf.tile). + * + * @param input The tensor that is to be broadcasted. + * @param shape The input is to be broadcast to this shape. + */ +/** @doc {heading: 'Tensors', subheading: 'Transformations'} */ +function broadcastTo_( x: Tensor|TensorLike, shape: ShapeMap[R] ): Tensor +{ + let input = convertToTensor(x, 'broadcastTo', 'x'); + const x_shape = input.shape; + + if( shape.some( d => d < 0 ) ) + throw new Error(`broadcastTo(): Invalid broadcast shape [${shape}].`); + + if( shape.length < input.rank ) throw new Error(`broadcastTo(): shape.length=${shape.length} < input.rank=${input.rank}.`); + if( shape.length > input.rank ) + { + const newShape = input.shape.slice(); + while( newShape.length < shape.length ) + newShape.unshift(1); + input = input.reshape(newShape); + } + + const reps: number[] = Array.from(shape); + for( let i=shape.length; i-- > 0; ) + { + if( input.shape[i] === shape[i] ) + reps[i] = 1; + else if( input.shape[i] !== 1 ) + throw new Error(`broadcastTo(): [${x_shape}] cannot be not broadcast to [${shape}].`); + } + + const axes = reps.map( ( n,i) => n > 1 ? 
i : -1 ).filter( i => i >= 0 ); + + if( axes.length === 0 ) + return input as Tensor; + + return ENV.engine.runKernel( + backend => backend.tile(input,reps), + {input}, + (dy: Tensor) => ({ + input: () => dy.sum(axes,/*keepDims=*/true) + }) + ) as Tensor; +} + /** * Creates a `Tensor` with values sampled from a normal distribution. * @@ -545,6 +597,10 @@ function tile_(x: T|TensorLike, reps: number[]): T { $x.rank === reps.length, `Error in transpose: rank of input ${$x.rank} ` + `must match length of reps ${reps}.`); + + if( reps.every( d => d === 1 ) ) + return $x; + const grad = (dy: T) => { const derX = () => { let xGrad = zerosLike($x); @@ -1149,6 +1205,7 @@ export const cumsum = op({cumsum_}); export const depthToSpace = op({depthToSpace_}); export const expandDims = op({expandDims_}); export const eye = op({eye_}); +export const broadcastTo = op({broadcastTo_}); export const fromPixels = op({fromPixels_}); export const multinomial = op({multinomial_}); export const oneHot = op({oneHot_}); diff --git a/src/ops/array_ops_test.ts b/src/ops/array_ops_test.ts index fa5fd1f3f1..36104a2bc5 100644 --- a/src/ops/array_ops_test.ts +++ b/src/ops/array_ops_test.ts @@ -17,8 +17,9 @@ import * as tf from '../index'; import {describeWithFlags} from '../jasmine_util'; -import {ALL_ENVS, BROWSER_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange, NODE_ENVS, WEBGL_ENVS} from '../test_util'; +import {numDiff, ALL_ENVS, BROWSER_ENVS, CPU_ENVS, expectArraysClose, expectArraysEqual, expectPromiseToFail, expectValuesInRange, NODE_ENVS, WEBGL_ENVS} from '../test_util'; import * as util from '../util'; +import {Tensor, Scalar} from '../tensor'; import {expectArrayInMeanStdRange, jarqueBeraNormalityTest} from './rand_util'; @@ -1922,6 +1923,63 @@ describeWithFlags('clone', ALL_ENVS, () => { }); }); +describeWithFlags('broadcastTo', ALL_ENVS, () => { + + it('[] -> [3,2]', () => { + const a = tf.scalar(2); + const A = tf.tensor2d([[2,2], + [2,2], + [2,2]]); + + const w = tf.randomUniform(A.shape); + const f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean() as Scalar; + + expectArraysEqual( A, tf.broadcastTo(a,A.shape) ); + + const g = tf.grad(f), + h = numDiff(f); + + expectArraysClose( g(a), h(a) ); + }); + + it('[2] -> [3,2]', () => { + const a = tf.tensor1d([1,2]); + const A = tf.tensor2d([[1,2], + [1,2], + [1,2]]); + + const w = tf.randomUniform(A.shape); + const f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean() as Scalar; + + expectArraysEqual( A, tf.broadcastTo(a,A.shape) ); + + const g = tf.grad(f), + h = numDiff(f); + + expectArraysClose( g(a), h(a) ); + }); + + it('[3,1] -> [3,2]', () => { + const a = tf.tensor2d([[1], + [2], + [3]]); + const A = tf.tensor2d([[1,1], + [2,2], + [3,3]]); + + const w = tf.randomUniform(A.shape); + const f = (a: Tensor) => tf.broadcastTo(a,A.shape).mul(w).mean() as Scalar; + + expectArraysEqual( A, tf.broadcastTo(a,A.shape) ); + + const g = tf.grad(f), + h = numDiff(f); + + expectArraysClose( g(a), h(a) ); + }); + +}); + describeWithFlags('tile', ALL_ENVS, () => { it('1D (tile)', () => { const t = tf.tensor1d([1, 2, 3]); diff --git a/src/ops/complex_ops.ts b/src/ops/complex_ops.ts index 89b09f6a14..dc5b6236a7 100644 --- a/src/ops/complex_ops.ts +++ b/src/ops/complex_ops.ts @@ -52,6 +52,23 @@ function complex_(real: T|TensorLike, imag: T|TensorLike): T { backend => backend.complex($real, $imag), {$real, $imag}) as T; } +/** Computes the complex conjucate element-wise. + * + * @param x The input tensor. 
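+ *
+ * A minimal usage sketch (assuming the op is exposed as `tf.conj`, as the
+ * export list below suggests):
+ *
+ * ```js
+ * const x = tf.complex([-2, 5], [3, -4]);
+ * tf.conj(x).print(); // real part [-2, 5], imaginary part [-3, 4]
+ * ```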
+ * @returns The complex conjugate of `x`. + */ +/** @doc {heading: 'Operations', subheading: 'Basic math'} */ +function conj_( x: T ) : T { + if( ! x.dtype.startsWith('complex') ) + return x; + + return ENV.engine.runKernel( + backend => complex( real(x), imag(x).neg() ), + {x}, + dy => ({ x: () => conj(dy) }) + ); +} + /** * Returns the real part of a complex (or real) tensor. * @@ -92,5 +109,6 @@ function imag_(input: T|TensorLike): T { } export const complex = op({complex_}); +export const conj = op({conj_}); export const real = op({real_}); export const imag = op({imag_}); diff --git a/src/ops/complex_ops_test.ts b/src/ops/complex_ops_test.ts index 81f590e9c3..1c8c923744 100644 --- a/src/ops/complex_ops_test.ts +++ b/src/ops/complex_ops_test.ts @@ -57,6 +57,35 @@ describeWithFlags('complex64', ALL_ENVS, () => { }); }); +describeWithFlags('conj', ALL_ENVS, () => { + it('1D non-complex', () => { + const a = tf.tensor1d([1, -3, 2, 7, -4]); + expectArraysClose( tf.conj(a), a ); + }); + + it('2D non-complex', () => { + const a = tf.tensor2d([[1, -3, 2], + [-8, 7, -4]]); + expectArraysClose( tf.conj(a), a ); + }); + + it('1D', () => { + const a = tf.complex([1,2,7],[-3,+8,-4]), + b = tf.complex([1,2,7],[+3,-8,+4]); + expectArraysClose( tf.conj(a), b ); + }); + + it('2D', () => { + const a = tf.complex([[ 1, -3, 2], + [-8, 7, 0]], [[42,-1, 2], + [ 3, 0,-4]]); + const b = tf.complex([[ 1, -3, 2], + [-8, 7, 0]], [[-42,+1,-2], + [ -3,-0,+4]]); + expectArraysClose( tf.conj(a), b ); + }); +}); + const BYTES_PER_COMPLEX_ELEMENT = 4 * 2; describeWithFlags('complex64 memory', BROWSER_ENVS, () => { it('usage', () => { diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index 51daf3ffa6..d31e5f8954 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -20,15 +20,25 @@ */ import {ENV} from '../environment'; -import {dispose} from '../globals'; +// import {dispose} from '../globals'; +import {convertToTensor} from '../tensor_util_env'; import {Tensor, Tensor1D, Tensor2D} from '../tensor'; +import {TensorLike} from '../types'; import {assert} from '../util'; -import {eye, squeeze, stack, unstack} from './array_ops'; +import {squeeze, stack, broadcastTo} from './array_ops'; +import {matMul} from './matmul'; import {split} from './concat_split'; import {norm} from './norm'; import {op} from './operation'; import {sum} from './reduction_ops'; -import {tensor2d} from './tensor_ops'; +import {TypedArray} from '../types'; +import {broadcastMatrices} from './linalg_util'; +import {conj} from './complex_ops'; +import {zeros, ones} from './tensor_ops'; +import {upcastType} from '../types'; +import {scalar, range} from './tensor_ops'; +import {gather} from './segment_ops'; +import {add} from './binary_ops'; /** * Gram-Schmidt orthogonalization. @@ -106,158 +116,1146 @@ function gramSchmidt_(xs: Tensor1D[]|Tensor2D): Tensor1D[]|Tensor2D { } } -/** - * Compute QR decomposition of m-by-n matrix using Householder transformation. - * - * Implementation based on - * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf] - * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf) + +function lu_p_decomp_( a: Tensor ): [Tensor,Tensor] +{ + assert( a.rank >= 2, `Error in linalg.lu: input must have rank >= 2, got rank ${a.rank}.`); + assert( a.shape[a.rank-2] == a.shape[a.rank-1],`Error in linalg.lu: input must be square, got shape [${a.shape}].`) + assert( ! a.dtype.startsWith('complex'), `lu(): complex dtypes not supported.`); + assert( ! 
a.dtype.startsWith('int' ), `lu(): integer dtypes not supported.`); + + const DType = a.dtype, + LU_shape = Array.from( a.shape ), + P_shape = LU_shape.slice(0,-1), + [N] = LU_shape.slice(-1), + LU = a.dataSync().slice(), + P = new Int32Array(LU.length/N); + + for( let LU_off=0, + P_off=0; P_off < P.length; P_off += N, + LU_off += N*N ) + { + // INIT P + for( let i=0; i < N; i++ ) P[P_off+i] = i; + // LU DECOMPOSITION + for( let i=0; i < N; i++ ) + { + const row_i = LU_off + i*N; + // ROW PIVOTING + { + let p=i; + for( let j=i+1; j < N; j++ ) + if( Math.abs( LU[LU_off+P[P_off+j]*N+i] ) + > Math.abs( LU[LU_off+P[P_off+p]*N+i] ) ) + p=j; + if( i != p ) { + const P_p = P[P_off+i]; + P[P_off+i] = P[P_off+p]; + P[P_off+p] = P_p; // KEEP TRACK OF ROW SWAPS + const row_p = LU_off + p*N; + // SWAP ROWS + for( let j=0; j < N; j++ ) { + const tmp = LU[row_i+j]; + LU[row_i+j] = LU[row_p+j]; + LU[row_p+j] = tmp; + } + } + } + // ELIMINATE ELEMENTS BELOW PIVOT + for( let j=i+1; j < N; j++ ) + { + const row_j = LU_off + j*N, + scale = LU[row_j+i] / LU[row_i+i]; + LU[row_j+i] = scale; + for( let k=i+1; k < N; k++ ) + LU[row_j+k] -= scale * LU[row_i+k]; + } + } + } + + const lu = Tensor.make(LU_shape,{ values: LU }, DType ); + const p = Tensor.make( P_shape,{ values: P },'int32'); + + return [lu,p]; +} + + +/** Computes the economic QR Decomposition. + */ +function qr_eco_decomp_( a: Tensor ): [Tensor,Tensor] +{ + assert( a.rank >= 2, `Error in linalg.qr: input must have rank >= 2, got rank ${a.rank}.`); + assert( ! a.dtype.startsWith('complex'), `Error in linalg.qr: complex dtype not supported.`); + assert( a.shape[a.rank-2] >= a.shape[a.rank-1], `Error in linalg.qr: a.shape[-2] = ${a.shape[a.rank-2]} < ${a.shape[a.rank-1]} = a.shape[-1].` ); + + const DType = 'float32', + DTypeArray = Float32Array, // <- ensure at least double precision + Q_shape = Array.from( a.shape ), + R_shape = Array.from( Q_shape ), + [N,M] = Q_shape.slice(-2); + R_shape[R_shape.length-2] = M; + Object.freeze(Q_shape); + Object.freeze(R_shape); + + const Q = DTypeArray.from( a.dataSync() ); a = undefined; // <- might encourage GC by setting `A = undefined` after this line + const R = new DTypeArray(Q.length/N*M), + cs = new DTypeArray(M*2),// <- CACHE cos AND sin VALUES TO APPLY M COLUMN ROTATIONS TO Q AT ONCE + r = function(){ + try { return cs.subarray(M); } + catch(e) { return new DTypeArray(M); } + }(); // <- additional space to temp. store rows of R not contained in the result + + for( + let R_off=0, + Q_off=0; Q_off < Q.length; Q_off += N*M, + R_off += M*M + ) + { + // HANDLE ENTRIES CONTAINED IN THE RESULT + for( let i=0; i < M; i++ ) + { + // COPY FROM Q TO R AND INIT Q + for( let j=0; j < M; j++ ) { + R[R_off+M*i+j] = Q[Q_off+M*i+j]; Q[Q_off+M*i+j] = i != j ? 
0.0 : 1.0 + } + + for( let j=0; j < i; j++ ) + { // USE GIVENS ROTATION TO ELIMINATE ELEMENT R_ji + const R_ij = R[R_off+M*i+j]; if( R_ij == 0.0 ) { cs[2*j+0]=1.0; cs[2*j+1]=0.0; continue }; + const R_jj = R[R_off+M*j+j], + norm = Math.hypot(R_jj,R_ij), + c = R_jj / norm, + s = R_ij / norm; + cs[2*j+0] = c; + cs[2*j+1] = s; + R[R_off + M*i+j] = 0.0; + R[R_off + M*j+j] = norm; + // ROTATE ROW i AND j IN R + for( let k=j; ++k < M; ) { + const ik = R_off+M*i+k, R_ik = R[ik], + jk = R_off+M*j+k, R_jk = R[jk]; + R[ik] = c*R_ik - s*R_jk; + R[jk] = s*R_ik + c*R_jk; + } + } + + // ROTATE COLUMNS IN Q (BUNDLED FOR BETTER CACHE LOCALITY) + for( let k=0; k <= i; k++ ) + for( let j=0; j < i; j++ ) { + const c = cs[2*j+0], + s = cs[2*j+1], + ki = Q_off+M*k+i, Q_ki = Q[ki], + kj = Q_off+M*k+j, Q_kj = Q[kj]; + Q[ki] = c*Q_ki - s*Q_kj; + Q[kj] = s*Q_ki + c*Q_kj; + } + } + // HANDLE REMAINING ENTRIES NOT CONTAINED IN THE RESULT + for( let i=M; i < N; i++ ) + { + // INIT r + for( let j=0; j < M; j++ ) { + r[j] = Q[Q_off+M*i+j]; Q[Q_off+M*i+j] = 0.0; + } + + // USE GIVENS ROTATIONS TO ELIMINATE ELEMENT r completely + for( let j=0; j < M; j++ ) + { + const r_j = r[j]; if( r_j == 0.0 ) { cs[2*j+0]=1.0; cs[2*j+1]=0.0; continue }; + const R_jj = R[R_off+M*j+j], + norm = Math.hypot(R_jj,r_j), + c = R_jj / norm, + s = r_j / norm; + R[R_off+M*j+j] = norm; + // ROTATE ROW i AND j IN R + for( let k=j; ++k < M; ) { + const jk = R_off + M*j+k, R_jk = R[jk]; + R[jk] = s*r[k] + c*R_jk; + r[ k] = c*r[k] - s*R_jk; + } + cs[2*j+0] = c; + cs[2*j+1] = s; + } + + // ROTATE COLUMNS IN Q + for( let k=0; k <= i; k++ ) { let Q_k = i != k ? 0.0 : 1.0; + for( let j=0; j < M; j++ ) { + const c = cs[2*j+0], + s = cs[2*j+1], q_k = Q_k, + kj = Q_off + M*k+j, Q_kj = Q[kj]; + Q_k = c*q_k - s*Q_kj; + Q[kj]= s*q_k + c*Q_kj; + }} + } + } + + { + const q = Tensor.make(Q_shape, { values: Q }, DType); + const r = Tensor.make(R_shape, { values: R }, DType); + + return [q,r]; + } +} + + +/** Computes the full QR Decomposition an memoizes the Givens rotation angles in the process. + */ +function qr_full_decomp_( a: Tensor ): [Tensor,Tensor,Tensor] +{ + assert( a.rank >= 2, `Error in linalg.qr: input must have rank >= 2, got rank ${a.rank}.`); + assert( ! 
a.dtype.startsWith('complex'), `Error in linalg.qr: complex dtype not supported.`); + + const DType ='float32', + DTypeArray = Float32Array, + R_shape = Array.from( a.shape ), + Q_shape = Array.from( a.shape ), + [M,N] = a.shape.slice(-2), + R = DTypeArray.from( a.dataSync() ); + a = undefined; + const L = Math.min(M,N), + Q = new DTypeArray( R.length/N*M ), + SIN = new DTypeArray( R.length/N/M * ( (L*(L-1) >>> 1) + Math.max(0,M-N)*N ) ); + Q_shape[Q_shape.length-1] = M; + Object.freeze(Q_shape); + Object.freeze(R_shape); + + let l = 0; + for( let Q_off=0, + R_off=0; Q_off < Q.length; Q_off += M*M, + R_off += M*N ) + { + // INIT Q TO IDENTITY + for( let i=0; i < M; i++ ) Q[Q_off + M*i+i] = 1; + + // BEGIN QR DECOMPOSITION + for( let i=1; i < M; i++ ) { const J = Math.min(i,N); + for( let j=0; j < J; j++ ) + { + // DETERMINE GIVENS ROTATION cos AND sin + const R_ij = R[R_off + N*i+j]; if( 0.0 == R_ij ) { SIN[l++]=0.0; continue; } + const R_jj = R[R_off + N*j+j]; + let norm = Math.hypot(R_jj,R_ij), + c = R_jj / norm, + s = R_ij / norm; + // MAKE c POSITIVE SO THAT IS CAN BE DEDUCED VIA c = sqrt(1-s²) + if( c < 0 ) { + c *= -1; + s *= -1; + norm *= -1; + } + SIN[l++] = s; + R[R_off + N*j+j] = norm; + R[R_off + N*i+j] = 0; + // ROTATE ROWS IN R + for( let k=j; ++k < N; ) + { const R_jk = R[R_off + N*j+k], + R_ik = R[R_off + N*i+k]; + R[R_off + N*j+k] = s*R_ik + c*R_jk; + R[R_off + N*i+k] = c*R_ik - s*R_jk; + } + // ROTATE ROWS IN Qᵀ + for( let k=0; k <= i; k++ ) + { const Q_jk = Q[Q_off + M*j+k], + Q_ik = Q[Q_off + M*i+k]; + Q[Q_off + M*j+k] = s*Q_ik + c*Q_jk; + Q[Q_off + M*i+k] = c*Q_ik - s*Q_jk; + } + }} // END QR DECOMPOSITION + + // TRANSPOSE Q (was transposed for cache locality) + for( let i=0; i < M; i++ ) + for( let j=0; j < i; j++ ) { + const Q_ij = Q[Q_off + M*i+j]; + Q[Q_off + M*i+j] = Q[Q_off + M*j+i]; + Q[Q_off + M*j+i] = Q_ij; + } + } + assert( l == SIN.length, `WTF: ${l} != ${SIN.length}` ); + + const q = Tensor.make(Q_shape,{values: Q},DType); + const r = Tensor.make(R_shape,{values: R},DType); + const sin = Tensor.make([SIN.length],{values: SIN},DType); + + return [q,r,sin]; +} + + +/** Computes the backpropagation full QR Decomposition using memoized Givens rotation + * angles in the process. 
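+ *
+ * Only the sine of each Givens rotation is memoized by the forward pass. Because
+ * the forward pass flips signs so that every cosine is non-negative, the cosine
+ * can be recovered as `c = sqrt( (1+s)*(1-s) )`, which allows this method to
+ * replay the rotations in reverse order.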
+ */ +function qr_full_backprop_( q: Tensor, dq: Tensor, r: Tensor, dr: Tensor, sin: Tensor ): Tensor +{ + assert( q.rank == dq.rank, `q.rank == ${q.rank} != ${dq.rank} == dq.rank` ) + assert( q.rank == dr.rank, `q.rank == ${q.rank} != ${dr.rank} == dr.rank` ) + assert( q.rank == r.rank, `q.rank == ${q.rank} != ${ r.rank} == r.rank` ) + + assert( sin.rank == 1, `sin.rank == ${sin.rank} != 1` ) + + for( let i=q.rank-2; i-- > 0; ) + { + assert( q.shape[i] == dq.shape[i], `q.shape[${i}] == ${q.shape[i]} != ${dq.shape[i]} == dq.shape[${i}]` ) + assert( q.shape[i] == dr.shape[i], `q.shape[${i}] == ${q.shape[i]} != ${dr.shape[i]} == dr.shape[${i}]` ) + assert( q.shape[i] == r.shape[i], `q.shape[${i}] == ${q.shape[i]} != ${ r.shape[i]} == r.shape[${i}]` ) + } + const rank = q.rank; + assert( q.shape[rank-2] == q.shape[rank-1], `q.shape[-2] == ${q.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` ) + assert( q.shape[rank-2] == dq.shape[rank-1], `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-1]} == dq.shape[-1]` ) + assert( q.shape[rank-2] == dq.shape[rank-2], `q.shape[-2] == ${q.shape[rank-2]} != ${dq.shape[rank-2]} == dq.shape[-2]` ) + + assert( r.shape[rank-2] == q.shape[rank-1], `r.shape[-2] == ${r.shape[rank-2]} != ${ q.shape[rank-1]} == q.shape[-1]` ) + assert( r.shape[rank-1] == dr.shape[rank-1], `r.shape[-1] == ${r.shape[rank-1]} != ${dr.shape[rank-1]} == dr.shape[-1]` ) + assert( r.shape[rank-2] == dr.shape[rank-2], `r.shape[-2] == ${r.shape[rank-2]} != ${dr.shape[rank-2]} == dr.shape[-2]` ) + + assert( q.dtype == dq.dtype, `q.dtype == ${q.dtype} == ${ dq.dtype} == dq.dtype` ) + assert( q.dtype == dr.dtype, `q.dtype == ${q.dtype} == ${ dr.dtype} == dr.dtype` ) + assert( q.dtype == r.dtype, `q.dtype == ${q.dtype} == ${ r.dtype} == r.dtype` ) + assert( q.dtype == sin.dtype, `q.dtype == ${q.dtype} == ${sin.dtype} == sin.dtype` ) + + assert( ! 
q.dtype.startsWith('complex'), `Complex dtype not supported.`); + + const DType ='float32', + DTypeArray = Float32Array, + dA_shape = Array.from( r.shape ), + [M,N] = dA_shape.slice(-2); + const Q = DTypeArray.from( q.dataSync() );// q = undefined; + const dQ = DTypeArray.from( dq.dataSync() ); dq = undefined; + const R = DTypeArray.from( r.dataSync() );// r = undefined; + const dR = DTypeArray.from( dr.dataSync() ); dr = undefined; + const SIN = sin.dataSync(); + Object.freeze(dA_shape); + + let l = SIN.length; + for( let R_off=R.length, + Q_off=Q.length; Q_off > 0; ) + { + Q_off -= M*M, + R_off -= M*N + + // TRANSPOSE Q (for cache locality) + for( let i=0; i < M; i++ ) + for( let j=0; j < i; j++ ) { + const Q_ij = Q[Q_off + M*i+j]; + Q[Q_off + M*i+j] = Q[Q_off + M*j+i]; + Q[Q_off + M*j+i] = Q_ij; + } + + // TRANSPOSE dQ (for cache locality) + for( let i=0; i < M; i++ ) + for( let j=0; j < i; j++ ) { + const dQ_ij = dQ[Q_off + M*i+j]; + dQ[Q_off + M*i+j] = dQ[Q_off + M*j+i]; + dQ[Q_off + M*j+i] = dQ_ij; + } + + // BEGIN QR DECOMPOSITION + for( let i=M; --i > 0; ) { const J = Math.min(i,N); + for( let j=J; j-- > 0; ) + { + // DETERMINE GIVENS ROTATION cos AND sin + const s = SIN[--l]; if( 0 == s ) continue; + const c = Math.sqrt((1+s)*(1-s)), + norm = R[R_off + N*j+j]; + + // ROTATE ROWS IN R + for( let k=j; k < N; k++ ) + { const R_jk = R[R_off + N*j+k], + R_ik = R[R_off + N*i+k]; + R[R_off + N*j+k] = c*R_jk - s*R_ik; + R[R_off + N*i+k] = s*R_jk + c*R_ik; + } + + // ROTATE ROWS IN Qᵀ + for( let k=0; k <= i; k++ ) + { const Q_jk = Q[Q_off + M*j+k], + Q_ik = Q[Q_off + M*i+k]; + Q[Q_off + M*j+k] = c*Q_jk - s*Q_ik; + Q[Q_off + M*i+k] = s*Q_jk + c*Q_ik; + } + + const R_ij = R[R_off + N*i+j], + R_jj = R[R_off + N*j+j], + dc_dj = + R_ij / norm * R_ij / norm**2, + dc_di = - R_ij / norm * R_jj / norm**2, + ds_dj = - R_jj / norm * R_ij / norm**2, + ds_di = + R_jj / norm * R_jj / norm**2; + let dj = 0.0, + di = 0.0; + + // ROTATE ROWS IN dR + for( let k=j; k < N; k++ ) + { const dR_jk = dR[R_off + N*j+k], + dR_ik = dR[R_off + N*i+k]; + dR[R_off + N*j+k] = c*dR_jk - s*dR_ik; + dR[R_off + N*i+k] = s*dR_jk + c*dR_ik; + + const R_jk = R[R_off + N*j+k], + R_ik = R[R_off + N*i+k]; + + dj += dR_jk*(R_ik*ds_dj + R_jk*dc_dj) + dR_ik*(R_ik*dc_dj - R_jk*ds_dj); + di += dR_jk*(R_ik*ds_di + R_jk*dc_di) + dR_ik*(R_ik*dc_di - R_jk*ds_di); + } + + // ROTATE ROWS IN dQᵀ + for( let k=0; k <= i; k++ ) + { const dQ_jk = dQ[Q_off + M*j+k], + dQ_ik = dQ[Q_off + M*i+k]; + dQ[Q_off + M*j+k] = c*dQ_jk - s*dQ_ik; + dQ[Q_off + M*i+k] = s*dQ_jk + c*dQ_ik; + + const Q_jk = Q[Q_off + M*j+k], + Q_ik = Q[Q_off + M*i+k]; + + dj += dQ_jk*(Q_ik*ds_dj + Q_jk*dc_dj) + dQ_ik*(Q_ik*dc_dj - Q_jk*ds_dj); + di += dQ_jk*(Q_ik*ds_di + Q_jk*dc_di) + dQ_ik*(Q_ik*dc_di - Q_jk*ds_di); + } + + dR[R_off + N*j+j] += dj; + dR[R_off + N*i+j] += di; + }} // END QR DECOMPOSITION + } + assert( l == 0, `WTF: ${l} != 0` ); + + return Tensor.make(dA_shape,{values: dR},DType); +} + + + +/** Returns a copy of a tensor of matrices with a different main diagonal. + * + * @param a Tensor of shape [..., M,N ]. + * @param d Tensor of shape [...,min(M,N)]. + * + * @returns Tensor of shape [..., M,N ]. A new tensor comprised of the off-diagonal + * entries of a and the main diagonal set to the entries of d. 
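+ *
+ * A small usage sketch (assuming the op is exposed as `tf.linalg.setDiag`, per
+ * the @doc namespace below):
+ *
+ * ```js
+ * const a = tf.tensor2d([[1, 2], [3, 4]]);
+ * const d = tf.tensor1d([9, 8]);
+ * tf.linalg.setDiag(a, d).print(); // [[9, 2], [3, 8]]
+ * ```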
+ */ +/** @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function setDiag_( a: Tensor|TensorLike, d: Tensor|TensorLike ): Tensor +{ + let $a = convertToTensor(a,'a','setDiag'); if( $a.rank < 2 ) throw new Error(`setDiag(): a.rank=${$a.rank} < 2`); + let $d = convertToTensor(d,'d','setDiag'); if( $d.rank < 1 ) throw new Error(`setDiag(): d.rank=${$d.rank} < 1`); + + const dtype = upcastType($a.dtype, $d.dtype); + if( $a.dtype != dtype ) $a = $a.cast(dtype); + if( $d.dtype != dtype ) $d = $d.cast(dtype); + + const rank: number = Math.max($a.rank-1, $d.rank), + shape: number[] = new Array(rank); + + if( $d.shape[$d.rank-1] != Math.min( ...$a.shape.slice(-2) ) ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${$a.shape}] [${$d.shape}]`) + + // FIND COMMON (BROADCASTED) SHAPE + for( let i=$a.rank-2, + j=$d.rank-1, + k= rank-1; i > 0 || j > 0; ) + { + i--; j--; k--; + if( 1 === $a.shape[i] ) + shape[k] = $d.shape[j] || 1; + else if( $a.shape[i] != $d.shape[j] && $d.shape[j] != 1 ) + throw new Error(`setDiag(): Incompatible shapes for a and d [${$a.shape}] [${$d.shape}]`); + else + shape[k] = $a.shape[i] || 1; + } + + shape[shape.length-1] = $d.shape[$d.rank-1]; $d = broadcastTo($d,shape); + shape[shape.length-1] = $a.shape[$a.rank-2]; + shape .push($a.shape[$a.rank-1]); $a = broadcastTo($a,shape); + Object.freeze(shape); + const $d_shape = $d.shape; + + return ENV.engine.runKernel( + backend => backend.matrixSetDiag($a,$d), + {$a,$d}, + dy => ({ + $a: () => setDiag(dy, zeros($d_shape) ), + $d: () => diagPart(dy) + }) + ); +} + + + +/** Returns the main diagonals from a tensor of matrices. + * The result is a tensor of a rank one less than the + * input tensor. + * + * @param a Tensor of shape [..., M,N]. The tensor whose main diagonal is returned. + * + * @returns Tensor of shape [...,min(M,N)]. The main diagonal of `a`. + */ +/** @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function diagPart_( a: Tensor|TensorLike ): Tensor +{ + const $a = convertToTensor(a,'a','diagPart'), + $a_shape = $a.shape; + + if( $a.rank < 2 ) throw new Error('diagPart(): Input a.rank must be at least 2.'); + if( $a.shape.some( d => d < 0 || + d%1 !== 0 ) ) throw new Error(`diagPart(): Invalid input shape [${$a.shape}].`); + + return ENV.engine.runKernel( + backend => backend.matrixDiagPart($a), + {$a}, + dy => ({ + $a: () => setDiag( zeros($a_shape), dy ) + }) + ); +} + + + +/** Copies a tensor of matrices, setting everything outside a central band + * in each matrix to zero. 
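+ *
+ * An entry `a[..., i, j]` is kept if and only if `i-j <= numLower` and
+ * `j-i <= numUpper`, where a negative bound keeps the entire triangle on that
+ * side. A small usage sketch (assuming exposure as `tf.linalg.bandPart`, per
+ * the @doc namespace below):
+ *
+ * ```js
+ * const a = tf.ones([3, 3]);
+ * // keep one subdiagonal and the main diagonal:
+ * tf.linalg.bandPart(a, 1, 0).print(); // [[1, 0, 0], [1, 1, 0], [0, 1, 1]]
+ * ```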
+ */ +/** @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function bandPart_( a: T|TensorLike, numLower: number, numUpper: number ): T +{ + const $a = convertToTensor(a,'a','bandPart'); + if( numLower%1 != 0 ) throw new Error(`bandPart(): numLower=${numLower} is no integer.`); + if( numUpper%1 != 0 ) throw new Error(`bandPart(): numUpper=${numUpper} is no integer.`); + if( !(numLower <= $a.shape[$a.rank-2]) ) throw new Error(`bandPart() assertion failed: numLower <= nRows.`); + if( !(numUpper <= $a.shape[$a.rank-1]) ) throw new Error(`bandPart() assertion failed: numUpper <= nCols.`); + if( numLower < 0 ) numLower = $a.shape[$a.rank-2]; + if( numUpper < 0 ) numUpper = $a.shape[$a.rank-1]; + + return ENV.engine.runKernel( + backend => backend.matrixBandPart($a,numLower,numUpper), + {$a}, + (dy: T) => ({ + $a: () => bandPart(dy, numLower, numUpper) + }) + ); +} + + +/** Conjugates a tensor of matrices and then transposes the last two dimensions. + * The adjoint is also commonly known as the Hermitian Transpose. + * + * @param a Tensor of shape [...,M,N]. The tensor of matrices that is to be tranposed. + * + * @returns Tensor of shape [...,N,M]. The transpose of `a`. + */ +/** @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function adjoint_( a: T|TensorLike ): T +{ + const $a = convertToTensor(a,'a','bandPart'); + + const axes = Array.from( $a.shape, (_,i) => i ); + axes[axes.length-2] = axes.length-1; + axes[axes.length-1] = axes.length-2; + + const $a_T = $a.transpose(axes); + if( $a_T.dtype.startsWith('complex') ) + return conj($a_T) + return $a_T; +} + + +function choleskyKernel( a: Tensor ): T +{ + if( ! a.dtype.startsWith('float') ) throw new Error(`cholesky(): a.dtype=${a.dtype} not supported.`); + if( a.rank < 2 ) throw new Error(`cholesky(): a.rank={a.rank} < 2.`); + + const + dtype = a.dtype, + shape = a.shape, + [N,M] = shape.slice(-2), + L = a.dataSync().slice(); + a = undefined; + + if( N != M ) throw new Error('cholesky(): Last two dimensions must be square.') + + for( let off=0; off < L.length; off += N*N ) + // https://de.wikipedia.org/wiki/Cholesky-Zerlegung + for( let i=0; i j ) L[off + N*i+j] = sum / L[off + N*j+j]; + else { L[off + N*i+i] = Math.sqrt(sum); + if( isNaN(L[off + N*i+i]) ) + throw new Error('cholesky(): a contains NaNs or (near) negative semi-definite.'); + } + } + + return Tensor.make(shape,{values: L},dtype); +} + + +/** Computes the cholesky decomposition of a tensor of symmetric matrices. + * + * @param a Tensor of shape [...,N,N]. A tensor of symmetric matrices, for which the + * cholesky decomposition is computed. + * + * @returns Tensor of shape [...,N,N]. A tensor of lower triangular matrices `L` such that `L∙Lᵀ=A`. + */ +/** @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function cholesky_( a: T|TensorLike ): T +{ + const $a = convertToTensor(a,'a','cholesky'); + + // WHERE THE BACKPROP COMES FROM (⊗ is the Hadamard Product, i.e. 
elementwise multiplication): + // SEE: https://arxiv.org/pdf/1602.07527.pdf#page=3 + // + // some preliminaries: + // (1) tr(A∙◣) = tr( triu(A) ∙ ◣ ) => the strict upper triangle can be chosen at will + // (2) tr(A∙◥) = tr( tril(A) ∙ ◥ ) => the strict lower triangle can be chosen at will + // + // now for the pertubation: + // A = L∙Lᵀ + // => dA = dL∙Lᵀ + L∙dLᵀ + // => L⁻¹∙dA∙L⁻ᵀ = L⁻¹∙dL + dLᵀ∙L⁻ᵀ (= ◣ + ◥) + // => L⁻¹∙dL = tril( L⁻¹∙dA∙L⁻ᵀ - ½diag(L⁻¹∙dA∙L⁻ᵀ) ) + // => dL = L∙tril( L⁻¹∙dA∙L⁻ᵀ - ½diag(L⁻¹∙dA∙L⁻ᵀ) ) + // => dL = L∙Φ(L⁻¹∙dA∙L⁻ᵀ) + // where: + // ⎧ i>j: X[i,j] + // Φ(X)[i,j] := ⎨ i=j: X[i,j]/2 + // ⎩ i df = tr(ℒᵀ∙dL) + // = 0.5*tr(ℒᵀ∙L∙L⁻¹∙dL) + 0.5*tr(Lᵀ∙ℒ∙dLᵀ∙L⁻ᵀ) ( = 0.5*tr(ℒᵀ∙L ∙ ◣) + 0.5*tr(Lᵀ∙ℒ ∙ ◥) ) + // + // (1) & (2) => df = 0.5*tr( { Φ(Lᵀ∙ℒ) + Φᵀ(Lᵀ∙ℒ) } ∙ { L⁻¹∙dL + dLᵀ∙L⁻ᵀ } ) + // = 0.5*tr( { Φ(Lᵀ∙ℒ) + Φᵀ(Lᵀ∙ℒ) } ∙ L⁻¹ ∙ dA ∙ L⁻ᵀ ) + // = 0.5*tr( L⁻ᵀ ∙ { Φ(Lᵀ∙ℒ) + Φᵀ(Lᵀ∙ℒ) } ∙ L⁻¹ ∙ dA ) + // = 0.5*tr( L⁻¹ ∙ { Φ(Lᵀ∙ℒ) + Φᵀ(Lᵀ∙ℒ) } ∙ L⁻ᵀ ∙ dAᵀ ) + // = 0.5*tr( L⁻ᵀ ∙ { Φ(Lᵀ∙ℒ) + Φᵀ(Lᵀ∙ℒ) } ∙ L⁻¹ ∙ dAᵀ ) + // + // => ∂f/∂A = 0.5 * L⁻ᵀ∙{Φ (Lᵀ∙ℒ) + Φᵀ(Lᵀ∙ℒ)} ∙ L⁻¹ + // = 0.5 * L⁻ᵀ∙ Φ (Lᵀ∙ℒ) ∙ L⁻¹ + // + 0.5 *(L⁻ᵀ∙ Φᵀ(Lᵀ∙ℒ) ∙ L⁻¹)ᵀ + // = 0.5 * L⁻ᵀ∙ { (Lᵀ∙ℒ)ᵀ + tril(Lᵀ∙ℒ) } ∙ L⁻¹ + + return ENV.engine.runKernel( + (backend,saveFn) => saveFn(/*L=*/choleskyKernel($a)), + {$a}, + (dL,[L]) => ({ + $a: () => { + // TODO: is tidy required here? +// dL = bandPart(dL,-1,0); + let dA = matMul(L,dL,/*adjoint_L=*/true); + dA = bandPart(dA,-1,0); + const diag0 = zeros( dA.shape.slice(0,-1) ); + dA = setDiag(dA,diag0).add( adjoint(dA) ); + dA = triangularSolve(L,dA,/*lower=*/true,/*adjoint_L=*/true); dA = adjoint(dA); + dA = triangularSolve(L,dA,/*lower=*/true,/*adjoint_L=*/true); + const HALF = scalar(0.5); + dA = setDiag( dA, diagPart(dA).mul(HALF) ); // <- FIXME: where is this HALF coming from ??? + dA = bandPart(dA,-1,0); + return dA as T; + } + }) + ); +} + + +/** Given a lower triangular matrix l and a right-hand-side y, this method solves + * the LES `L∙Lᵀ∙X = Y`. * - * ```js - * const a = tf.tensor2d([[1, 2], [3, 4]]); - * let [q, r] = tf.linalg.qr(a); - * console.log('Q'); - * q.print(); - * console.log('R'); - * r.print(); - * console.log('Orthogonalized'); - * q.dot(q.transpose()).print() // should be nearly the identity matrix. - * console.log('Reconstructed'); - * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]]; - * ``` + * @param l Tensor of shape [...,N,N]. A tensor of lower triangular matrices. + * @param y Tensor of shape [...,N,M]. A tensor of right-hand-side matrices. * - * @param x The `Tensor` to be QR-decomposed. Must have rank >= 2. Suppose - * it has the shape `[..., M, N]`. - * @param fullMatrices An optional boolean parameter. Defaults to `false`. - * If `true`, compute full-sized `Q`. If `false` (the default), - * compute only the leading N columns of `Q` and `R`. - * @returns An `Array` of two `Tensor`s: `[Q, R]`. `Q` is a unitary matrix, - * i.e., its columns all have unit norm and are mutually orthogonal. - * If `M >= N`, - * If `fullMatrices` is `false` (default), - * - `Q` has a shape of `[..., M, N]`, - * - `R` has a shape of `[..., N, N]`. - * If `fullMatrices` is `true` (default), - * - `Q` has a shape of `[..., M, M]`, - * - `R` has a shape of `[..., M, N]`. - * If `M < N`, - * - `Q` has a shape of `[..., M, M]`, - * - `R` has a shape of `[..., M, N]`. - * @throws If the rank of `x` is less than 2. + * @returns Tensor of shape [...,N,M]. The solution `X` of the LES `L∙Lᵀ∙X = Y`. 
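+ *
+ * A usage sketch (assuming exposure as `tf.linalg.cholesky` and
+ * `tf.linalg.choleskySolve`, per the exports below):
+ *
+ * ```js
+ * const a = tf.tensor2d([[4, 2], [2, 3]]); // symmetric positive definite
+ * const y = tf.tensor2d([[6], [5]]);
+ * const l = tf.linalg.cholesky(a);
+ * tf.linalg.choleskySolve(l, y).print(); // ≈ [[1], [1]], since A∙[1,1]ᵀ = [6,5]ᵀ
+ * ```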
*/
+/** @doc {heading:'Operations',
+ * subheading:'Linear Algebra',
+ * namespace:'linalg'}
+ */
+function choleskySolve_( l: Tensor, y: Tensor ): Tensor
+{
+  // Solve L∙Lᵀ∙X = Y via two triangular solves: L∙Z = Y first, then Lᵀ∙X = Z.
+  let x = triangularSolve(l,y,/*lower=*/true,/*adjoint_L=*/false);
+      x = triangularSolve(l,x,/*lower=*/true,/*adjoint_L=*/true );
+  return x;
+}
+
+
+/** Solves a triangular linear equation system (LES).
+ *
+ * @param l The triangular matrix of the LES.
+ * @param y The right-hand-side of the LES.
+ * @param lower If set to `true`, `l` is interpreted as a lower triangular matrix
+ *              and the strict upper triangular entries are ignored. If set to
+ *              `false`, `l` is interpreted as an upper triangular matrix and the
+ *              strict lower triangular entries are ignored.
+ * @param adjoint If set to `true`, the Hermitian transpose of `l` is used in the LES.
+ *
+ * @returns The solution of one of the following LES:
+ *
+ *   <dl>
+ *     <dt>lower=false, adjoint=false <dd>triu(l) ∙x == y
+ *     <dt>lower=true,  adjoint=false <dd>tril(l) ∙x == y
+ *     <dt>lower=false, adjoint=true  <dd>triu(l)ᴴ∙x == y
+ *     <dt>lower=true,  adjoint=true  <dd>tril(l)ᴴ∙x == y
+ *   </dl>
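+ *
+ * A usage sketch (assuming exposure as `tf.linalg.triangularSolve`, per the
+ * exports below; `lower` defaults to `true` and `adjoint` to `false`):
+ *
+ * ```js
+ * const l = tf.tensor2d([[2, 0], [3, 4]]);
+ * const y = tf.tensor2d([[2], [11]]);
+ * tf.linalg.triangularSolve(l, y).print(); // ≈ [[1], [2]]
+ * ```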
+ */ +/** @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function triangularSolve_( l: Tensor|TensorLike, y: Tensor|TensorLike, lower=true, adjoint=false ): Tensor +{ + // FIXME: if `l` is singular the right hand side could be checked for 0 and then some solution could be used + let [$l,$y] = broadcastMatrices( + convertToTensor(l,'l','triangularSolve'), + convertToTensor(y,'y','triangularSolve') + ); + l=undefined; + y=undefined; + if( $l.rank < 2 ) throw new Error(`triangularSolve(): l.rank must be at least 2.`); + if( $y.rank < 2 ) throw new Error(`triangularSolve(): y.rank must be at least 2.`); + + const dtype = upcastType($l.dtype, $y.dtype); + if( $l.dtype != dtype ) $l = $l.cast(dtype); + if( $y.dtype != dtype ) $y = $y.cast(dtype); + + // WHERE THE BACKPROP COMES FROM: + // x = L⁻¹∙y + // => dx = d(L⁻¹)∙y + L⁻¹∙dy = L⁻¹∙dy - L⁻¹∙dL∙L⁻¹∙y = L⁻¹∙dy - L⁻¹∙dL∙x + // => df = tr( (∂f/∂x)∙dxᵀ ) + // = tr( (∂f/∂x)∙dyᵀ∙L⁻ᵀ ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙dLᵀ∙L⁻ᵀ ) + // = tr( (∂f/∂x)ᵀ∙L⁻¹∙dy ) - tr( (∂f/∂x)∙yᵀ∙L⁻ᵀ∙(L⁻¹∙dL)ᵀ ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻¹∙y∙(∂f/∂x)ᵀ∙ L⁻¹∙dL ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( x∙(∂f/∂x)ᵀ∙ L⁻¹∙dL ) + // = tr( L⁻ᵀ∙(∂f/∂x) ∙dyᵀ) - tr( L⁻ᵀ ∙(∂f/∂x) ∙ xᵀ ∙dLᵀ ) + // => ∂f/∂y = L⁻ᵀ∙(∂f/∂x) + // ∂f/∂L = -L⁻ᵀ∙(∂f/∂x)∙xᵀ = ∂f/∂L = -(∂f/∂y)∙xᵀ + + // SEE: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L218 + return ENV.engine.runKernel( + (backend,saveFn) => { + const x = triangularSolveKernel($l,$y,lower,adjoint); + saveFn(x); + return x; + }, + {$l,$y}, + (dx,[x]) => { + const dy = triangularSolve($l, dx, lower, !adjoint); + return { + $l: () => { + let dl = adjoint ? matMul( x, dy, false, true) + : matMul(dy, x, false, true); + dl = dl.neg(); + dl = lower ? bandPart(dl,-1, 0) + : bandPart(dl, 0,-1); + return dl; + }, + $y: () => dy + }; + } + ); +} + + +function triangularSolveKernel( l: Tensor, y: Tensor, lower: boolean, adjoint: boolean ): Tensor +{ + if( ! l.dtype.startsWith('float') ) throw new Error(`triangularSolve(): l.dtype=${l.dtype} not supported.`); + if( ! y.dtype.startsWith('float') ) throw new Error(`triangularSolve(): y.dtype=${y.dtype} not supported.`); + if( l.rank < 2 ) throw new Error('triangularSolve(): l must be at least 2D.'); + if( y.rank < 2 ) throw new Error('triangularSolve(): y must be at least 2D.'); + if( l.rank != y.rank ) throw new Error('triangularSolve(): l and y must have same rank.'); + for( let i=l.rank-2; i-- > 0; ) + if( l.shape[i] != y.shape[i] ) throw new Error('triangularSolve(): leading dimensions do not match.'); + + const [N,M] = l.shape.slice(-2), + [I,J] = y.shape.slice(-2); + if( N != M ) throw new Error('triangularSolve(): Last two axes of L not square.'); + if( I != M ) throw new Error('triangularSolve(): L and y do not match.'); + + const + rank = Math.max(l.rank, y.rank), + X_shape = Array.from(l.shape); + X_shape[rank-2] = I; + X_shape[rank-1] = J; + + // GENERATE RESULT DATA + const + dtype = 'float32', +// dtype = ( l.dtype === 'float64' || +// y.dtype === 'float64' ) ? 'float64' : 'float32', + DTypeArray = dtype === 'float32' ? Float32Array + : Float64Array, + L = l.dataSync(), + X = DTypeArray.from( y.dataSync() ) as TypedArray; + let L_off = 0, + X_off = 0; + l = undefined; + y = undefined; + + function solv( d: number ): void { + if( d === rank-2 ) { + if( ! 
adjoint ) + { + if(lower) // FORWARD SUBSTITUTION + for( let i=0; i < I; i++ ) { + for( let k=0; k < i; k++ ) + for( let j=0; j < J; j++ ) X[X_off + J*i+j] -= L[L_off + N*i+k] * X[X_off + J*k+j] + + for( let j=0; j < J; j++ ) X[X_off + J*i+j] /= L[L_off + N*i+i] + } + else // BACKWARD SUBSTITUTION + for( let i=I; i-- > 0; ) { + for( let j=J; j-- > 0; ) X[X_off + J*i+j] /= L[L_off + N*i+i] + + for( let k=i; k-- > 0; ) + for( let j=J; j-- > 0; ) X[X_off + J*k+j] -= L[L_off + N*k+i] * X[X_off + J*i+j] + } + } + else + { + if(lower) // BACKWARD SUBSTITUTION (TRANSPOSED) + for( let i=I; i-- > 0; ) { + for( let j=J; j-- > 0; ) X[X_off + J*i+j] /= L[L_off + N*i+i] + + for( let k=i; k-- > 0; ) + for( let j=J; j-- > 0; ) X[X_off + J*k+j] -= L[L_off + N*i+k] * X[X_off + J*i+j] + } + else // FORWARD SUBSTITUTION (TRANSPOSED) + for( let i=0; i < I; i++ ) { + for( let k=0; k < i; k++ ) + for( let j=0; j < J; j++ ) X[X_off + J*i+j] -= L[L_off + N*k+i] * X[X_off + J*k+j] + + for( let j=0; j < J; j++ ) X[X_off + J*i+j] /= L[L_off + N*i+i] + } + } + + L_off += N*N; + X_off += N*J; + + return; + } + for( let l=X_shape[d]; l-- > 0; ) + solv(d+1); + } + solv(0); + + return Tensor.make(X_shape,{values: X},dtype); +} + + +/** Compute QR decomposition of m-by-n matrix using Givens rotations. + * + * Implementation based on + * [http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf] + * (http://www.cs.cornell.edu/~bindel/class/cs6210-f09/lec18.pdf) + * + * ```js + * const a = tf.tensor2d([[1, 2], [3, 4]]); + * let [q, r] = tf.linalg.qr(a); + * console.log('Q'); + * q.print(); + * console.log('R'); + * r.print(); + * console.log('Orthogonalized'); + * q.dot(q.transpose()).print() // should be nearly the identity matrix. + * console.log('Reconstructed'); + * q.dot(r).print(); // should be nearly [[1, 2], [3, 4]]; + * ``` + * + * @param x The `Tensor` to be QR-decomposed. Must have rank >= 2. Suppose + * it has the shape `[..., M, N]`. + * @param fullMatrices An optional boolean parameter. Defaults to `false`. + * If `true`, compute full-sized `Q`. If `false` (the default), + * compute only the leading N columns of `Q` and `R`. + * @returns An `Array` of two `Tensor`s: `[Q, R]`. `Q` is a unitary matrix, + * i.e., its columns all have unit norm and are mutually orthogonal. + * If `M >= N`, + * If `fullMatrices` is `false` (default), + * - `Q` has a shape of `[..., M, N]`, + * - `R` has a shape of `[..., N, N]`. + * If `fullMatrices` is `true` (default), + * - `Q` has a shape of `[..., M, M]`, + * - `R` has a shape of `[..., M, N]`. + * If `M < N`, + * - `Q` has a shape of `[..., M, M]`, + * - `R` has a shape of `[..., M, N]`. + * @throws If the rank of `x` is less than 2. + */ /** * @doc {heading:'Operations', * subheading:'Linear Algebra', * namespace:'linalg'} */ -function qr_(x: Tensor, fullMatrices = false): [Tensor, Tensor] { - if (x.rank < 2) { - throw new Error( - `qr() requires input tensor to have a rank >= 2, but got rank ${ - x.rank}`); - } else if (x.rank === 2) { - return qr2d(x as Tensor2D, fullMatrices); - } else { - // Rank > 2. - // TODO(cais): Below we split the input into individual 2D tensors, - // perform QR decomposition on them and then stack the results back - // together. We should explore whether this can be parallelized. 
- const outerDimsProd = x.shape.slice(0, x.shape.length - 2) - .reduce((value, prev) => value * prev); - const x2ds = unstack( - x.reshape([ - outerDimsProd, x.shape[x.shape.length - 2], - x.shape[x.shape.length - 1] - ]), - 0); - const q2ds: Tensor2D[] = []; - const r2ds: Tensor2D[] = []; - x2ds.forEach(x2d => { - const [q2d, r2d] = qr2d(x2d as Tensor2D, fullMatrices); - q2ds.push(q2d); - r2ds.push(r2d); - }); - const q = stack(q2ds, 0).reshape(x.shape); - const r = stack(r2ds, 0).reshape(x.shape); - return [q, r]; +function qr_( a: Tensor, fullMatrices = false ): [Tensor, Tensor] { + if( a.rank < 2 ) + throw new Error(`qr() requires input tensor to have a rank >= 2, but got rank ${a.rank}`); + if( a.dtype.startsWith('complex') ) + throw new Error(`qr() not yet supported for complex tensors.`); + + const [m,n] = a.shape.slice(-2) + + if( m == n || m > n && !fullMatrices ) + { + // FIXME: What if R is singular? + return ENV.engine.runKernel( + (backend,saveFunc) => { + const [q,r] = qr_eco_decomp_(a); + saveFunc(q); + saveFunc(r); + return [q,r]; + }, + { $a: a }, + ([dq,dr], [q,r]) => ({ + $a: () => { + // TODO: is tidy required here? + // https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/linalg_grad.py#L95 + const qdq = matMul(q,dq, true, false), qdq_ = qdq.sub( adjoint_(qdq) ), + rdr = matMul(r,dr, false, true), rdr_ = rdr.sub( adjoint_(rdr) ), + tril = bandPart( add(qdq_,rdr_), -1, 0 ); + + const triSolv = (x: Tensor,r: Tensor) => adjoint_( + triangularSolve(r, adjoint(x), /*lower=*/false, /*adjoint_r*/false) + ); + + const grad_a = matMul( q, dr.add( triSolv(tril,r) ) ), + grad_b = triSolv( dq.sub( matMul(q,qdq) ), r ); + + return add(grad_a,grad_b); + } + }) + ) as [Tensor, Tensor]; + } + + let [q,r] = ENV.engine.runKernel( + (backend,saveFunc) => { + const [q,r,sin] = qr_full_decomp_(a); + saveFunc(q); + saveFunc(r); + saveFunc(sin); + return [q,r]; + }, + { a: a }, + ([dq,dr], [q,r,sin]) => ({ + a: () => ENV.engine.runKernel( + (backend,saveFunc) => qr_full_backprop_(q,dq, r,dr, sin), + { $dq: dq, $dr: dr } + ) + }) + ); + + if( ! fullMatrices && m > n ) { + let end = a.shape.slice(); + q = q.slice([0, 0], end); end[end.length-2] = n; + r = r.slice([0, 0], end); } + + return [q,r]; } -function qr2d(x: Tensor2D, fullMatrices = false): [Tensor2D, Tensor2D] { - return ENV.engine.tidy(() => { - if (x.shape.length !== 2) { - throw new Error( - `qr2d() requires a 2D Tensor, but got a ${x.shape.length}D Tensor.`); - } +/** Applies the inverse of a row permutations, given by list of indices, to a Tensor. + * + * @param a Tensor of shape [...,M,N]. + * @param p Tensor of shape [...,M]. + * + * @returns b Tensor of shape [...,M,N], such that: `b[...,p[...,i],j] == a[...,i,j]`. + */ +function permuteRowsInv_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this? +{ + if( a.rank != p.rank+1 ) throw new Error(`permuteRows(): a.rank and p.rank do not match.`); + for( let i=p.rank; i-- > 0; ) + if( a.shape[i] != p.shape[i] ) + throw new Error(`permuteRows(): a.shape and p.shape incompatible.`); - const m = x.shape[0]; - const n = x.shape[1]; - - let q = eye(m) as Tensor2D; // Orthogonal transform so far. - let r = x.clone(); // Transformed matrix so far. - - const one2D = tensor2d([[1]], [1, 1]); - let w: Tensor2D = one2D.clone(); - - const iters = m >= n ? n : m; - for (let j = 0; j < iters; ++j) { - // This tidy within the for-loop ensures we clean up temporary - // tensors as soon as they are no longer needed. 
- const rTemp = r; - const wTemp = w; - const qTemp = q; - [w, r, q] = ENV.engine.tidy((): [Tensor2D, Tensor2D, Tensor2D] => { - // Find H = I - tau * w * w', to put zeros below R(j, j). - const rjEnd1 = r.slice([j, j], [m - j, 1]); - const normX = rjEnd1.norm(); - const rjj = r.slice([j, j], [1, 1]); - const s = rjj.sign().neg() as Tensor2D; - const u1 = rjj.sub(s.mul(normX)) as Tensor2D; - const wPre = rjEnd1.div(u1); - if (wPre.shape[0] === 1) { - w = one2D.clone(); - } else { - w = one2D.concat( - wPre.slice([1, 0], [wPre.shape[0] - 1, wPre.shape[1]]), 0) as - Tensor2D; - } - const tau = s.matMul(u1).div(normX).neg() as Tensor2D; - - // -- R := HR, Q := QH. - const rjEndAll = r.slice([j, 0], [m - j, n]); - const tauTimesW = tau.mul(w) as Tensor2D; - if (j === 0) { - r = rjEndAll.sub(tauTimesW.matMul(w.transpose().matMul(rjEndAll))); - } else { - r = r.slice([0, 0], [j, n]) - .concat( - rjEndAll.sub( - tauTimesW.matMul(w.transpose().matMul(rjEndAll))), - 0) as Tensor2D; - } - const qAllJEnd = q.slice([0, j], [m, q.shape[1] - j]); - if (j === 0) { - q = qAllJEnd.sub(qAllJEnd.matMul(w).matMul(tauTimesW.transpose())); - } else { - q = q.slice([0, 0], [m, j]) - .concat( - qAllJEnd.sub( - qAllJEnd.matMul(w).matMul(tauTimesW.transpose())), - 1) as Tensor2D; - } - return [w, r, q]; - }); - dispose([rTemp, wTemp, qTemp]); - } + const [N] = p.shape.slice(-1), + P = p.dataSync().slice(), + P_inv = new Int32Array(P.length); - if (!fullMatrices && m > n) { - q = q.slice([0, 0], [m, n]); - r = r.slice([0, 0], [n, n]); - } + for( let off=0; off < P.length; off += N ) + for( let i =0; i < N; i++ ) { + P_inv[off + P[off+i]] = off+i; + } + + return gather( + a.reshape( [-1].concat( a.shape.slice(-1) ) ), + P_inv + ).reshape( a.shape ) +} + +/** Applies a row permutations, given by list of indices, to a Tensor. + * + * @param a Tensor of shape [...,M,N]. + * @param p Tensor of shape [...,M]. + * + * @returns b Tensor of shape [...,M,N], such that: `b[...,i,j] == a[...,p[...,i],j]`. + */ +function permuteRows_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this? +{ + if( a.rank != p.rank+1 ) throw new Error(`permuteRows(): a.rank and p.rank do not match.`); + for( let i=p.rank; i-- > 0; ) + if( a.shape[i] != p.shape[i] ) + throw new Error(`permuteRows(): a.shape and p.shape incompatible.`); + + const [N] = p.shape.slice(-1); + + p = p.reshape( [-1].concat( p.shape.slice(-1) ) ); + p = p.add( + range(0, p.shape[0]*N, N, 'int32').reshape([-1,1]) + ); - return [q, r]; - }) as [Tensor2D, Tensor2D]; + return a.reshape( [-1].concat(a.shape.slice(-1)) ) + .gather( p.flatten() ) + .reshape( a.shape ); +} + +/** Computes the LU decomposition. + * + * @param a Tensor[...,N,N]. + * + * @returns [lu: Tensor[...,N,N], p: Tensor[...N]] Where `(L @ U)[i,j] = a[...,p[...,i],j]` + */ +/** + * @doc {heading:'Operations', + * subheading:'Linear Algebra', + * namespace:'linalg'} + */ +function lu_( a: Tensor, permute=true ): [Tensor, Tensor] +{ + // dA = dL∙U + L∙dU + // = L⁻¹∙dA∙U⁻¹ = L⁻¹∙dL + dU∙U⁻¹ = strict(◣) + ◥ + // => dL = L∙strict_tril(L⁻¹∙dA∙U⁻¹) + // => dU = triu(L⁻¹∙dA∙U⁻¹)∙U + // => d(LU) = dL + dU + // ℒ := ∂f/∂(LU) = strict(◣) + // df = tr( ℒ∙d(LU)ᵀ ) + // = tr( ℒ∙dLᵀ ) + tr( ℒ∙dUᵀ ) + // = tr( Lᵀ∙ℒ∙{strict_tril(L⁻¹∙dA∙U⁻¹)}ᵀ ) + tr( ℒ∙Uᵀ∙{triu(L⁻¹∙dA∙U⁻¹)}ᵀ ) + // = tr( { strict_tril(Lᵀ∙ℒ) + triu(ℒ∙Uᵀ) } ∙ (L⁻¹∙dA∙U⁻¹)ᵀ ) + // = tr( L⁻ᵀ∙{ strict_tril(Lᵀ∙ℒ) + triu(ℒ∙Uᵀ) }∙U⁻ᵀ ∙ dAᵀ ) + // => ∂f/∂A = L⁻ᵀ∙{ strict_tril(Lᵀ∙ℒ) + triu(ℒ∙Uᵀ) }∙U⁻ᵀ + + const $a = convertToTensor(a,'a','lu'); + + if( ! 
+/** Computes the LU decomposition.
+ *
+ * @param a Tensor[...,N,N].
+ *
+ * @returns [lu: Tensor[...,N,N], p: Tensor[...,N]] Where `(L @ U)[...,i,j] == a[...,p[...,i],j]`
+ */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
+function lu_( a: Tensor, permute=true ): [Tensor, Tensor]
+{
+  //    dA = dL∙U + L∙dU
+  // => L⁻¹∙dA∙U⁻¹ = L⁻¹∙dL + dU∙U⁻¹ = strict(◣) + ◥
+  // => dL = L∙strict_tril(L⁻¹∙dA∙U⁻¹)
+  // => dU = triu(L⁻¹∙dA∙U⁻¹)∙U
+  // => d(LU) = dL + dU
+  //    ℒ := ∂f/∂(LU)
+  //    df = tr( ℒ∙d(LU)ᵀ )
+  //       = tr( ℒ∙dLᵀ ) + tr( ℒ∙dUᵀ )
+  //       = tr( Lᵀ∙ℒ∙{strict_tril(L⁻¹∙dA∙U⁻¹)}ᵀ ) + tr( ℒ∙Uᵀ∙{triu(L⁻¹∙dA∙U⁻¹)}ᵀ )
+  //       = tr( { strict_tril(Lᵀ∙ℒ) + triu(ℒ∙Uᵀ) } ∙ (L⁻¹∙dA∙U⁻¹)ᵀ )
+  //       = tr( L⁻ᵀ∙{ strict_tril(Lᵀ∙ℒ) + triu(ℒ∙Uᵀ) }∙U⁻ᵀ ∙ dAᵀ )
+  // => ∂f/∂A = L⁻ᵀ∙{ strict_tril(Lᵀ∙ℒ) + triu(ℒ∙Uᵀ) }∙U⁻ᵀ
+
+  const $a = convertToTensor(a,'a','lu');
+
+  if( ! permute )
+    throw new Error('lu(): permute=false not yet implemented.');
+
+  return ENV.engine.runKernel(
+    (backend,saveFunc) => {
+      const [lu,p] = lu_p_decomp_($a);
+      saveFunc(lu);
+      saveFunc(p );
+      return [lu,p];
+    },
+    { $a: $a },
+    ([dLU,dP],[lu,p]) => ({
+      $a: () => {
+        const diag0 = zeros(p.shape),
+              diag1 =  ones(p.shape);
+
+        const l = setDiag( bandPart(lu,-1, 0), diag1 ),
+              u =          bandPart(lu, 0,-1);
+
+        let ℒL = matMul( dLU,l, true, false); ℒL = bandPart(ℒL, 0,-1); ℒL = setDiag(ℒL, diag0);
+        let Uℒ = matMul(u,dLU, false, true ); Uℒ = bandPart(Uℒ,-1, 0);
+
+        let dA = add(ℒL,Uℒ);
+        dA = triangularSolve(u, dA, /*lower=*/false); dA = adjoint(dA);
+        dA = triangularSolve(l, dA, /*lower=*/true, /*adjoint_l=*/true);
+        return permuteRowsInv_(dA,p);
+      }
+    })
+  ) as [Tensor, Tensor];
+}
+
+/** Solves a linear equation system (LES) using its LU decomposition.
+ *
+ * @param lu Tensor of shape [...,N,N]. Tensor containing the data of L and U,
+ *           the triangular matrices resulting from the LU decomposition.
+ * @param p Tensor of shape [...,N]. The permutation indices from the LU decomposition.
+ * @param y Tensor of shape [...,N,M]. The right-hand side of the LES.
+ */
+/**
+ * @doc {heading:'Operations',
+ *       subheading:'Linear Algebra',
+ *       namespace:'linalg'}
+ */
+function luSolve_( lu: Tensor|TensorLike, p: Tensor|TensorLike, y: Tensor|TensorLike = undefined ): Tensor
+{
+  let $lu = convertToTensor(lu,'lu','luSolve'); lu = undefined;
+
+  const [n] = $lu.shape.slice(-1);
+
+  let $p: Tensor,
+      $y: Tensor;
+
+  if( null == y ) {
+    $y = convertToTensor(p, 'p', 'luSolve');
+    $p = range(0,n,1,'int32');
+  } else {
+    $p = convertToTensor(p, 'p', 'luSolve');
+    $y = convertToTensor(y, 'y', 'luSolve');
+  }
+  lu = p = y = null;
+
+  if( $lu.rank < 2 ) throw new Error('luSolve(): lu.rank must be at least 2.');
+
+  if( $lu.shape[$lu.rank-1] != $lu.shape[$lu.rank-2] ) throw new Error('luSolve(): lu must be square.');
+  if( $lu.shape[$lu.rank-1] != $p.shape[$p .rank-1] ) throw new Error('luSolve(): lu and p of incompatible shape.');
+  if( $lu.shape[$lu.rank-1] != $y.shape[$y .rank-2] ) throw new Error('luSolve(): lu and y of incompatible shape.');
+
+  const rank = Math.max($lu.rank, $p.rank, $y.rank);
+  const shape = $p.shape.slice(0,-1);
+  while( shape.length < rank-2 )
+    shape.unshift(1);
+
+  // FIND COMMON (BROADCASTED) SHAPE
+  for( const tensor of [$lu,$y] )
+    for( let i=rank-2, j=tensor.rank-2; i-- > 0 && j-- > 0; )
+      if( 1 === shape[i] )
+        shape[i] = tensor.shape[j];
+      else if( shape[i] != tensor.shape[j] && tensor.shape[j] != 1 )
+        throw new Error(`luSolve(): Shapes not broadcast-compatible.`);
+
+  $lu = broadcastTo( $lu, shape.concat($lu.shape.slice(-2)) );
+  $p  = broadcastTo( $p , shape.concat($p .shape.slice(-1)) );
+  $y  = broadcastTo( $y , shape.concat($y .shape.slice(-2)) );
+
+  const l = setDiag( $lu, ones($p.shape) );
+
+  let x = permuteRows_( $y, $p );
+  x = triangularSolve( l ,x,/*lower=*/true );
+  x = triangularSolve($lu,x,/*lower=*/false);
+  return x;
 }
 
 export const gramSchmidt = op({gramSchmidt_});
+
+export const adjoint  = op({adjoint_});
+export const setDiag  = op({setDiag_});
+export const diagPart = op({diagPart_});
+export const bandPart = op({bandPart_});
+
 export const qr = op({qr_});
+export const lu = op({lu_});
+export const luSolve = op({luSolve_});
+export const cholesky        = op({ cholesky_ });
+export const choleskySolve   = op({ choleskySolve_ });
+export const triangularSolve = op({ triangularSolve_ });
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index 7480d7fcee..0e31405840 100644
---
a/src/ops/linalg_ops_test.ts +++ b/src/ops/linalg_ops_test.ts @@ -17,11 +17,14 @@ import * as tf from '../index'; import {describeWithFlags} from '../jasmine_util'; -import {Tensor1D, Tensor2D} from '../tensor'; -import {ALL_ENVS, expectArraysClose, WEBGL_ENVS} from '../test_util'; +import {Scalar, Tensor, Tensor1D, Tensor2D} from '../tensor'; +import {broadcastMatrices} from './linalg_util'; +import {CPU_ENVS, ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS, numDiff} from '../test_util'; import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops'; +const randInt = (from: number, until: number) => Math.floor(Math.random()*(until-from)) + from; + describeWithFlags('gramSchmidt-tiny', ALL_ENVS, () => { it('2x2, Array of Tensor1D', () => { const xs: Tensor1D[] = [ @@ -94,137 +97,573 @@ describeWithFlags('gramSchmidt-non-tiny', WEBGL_ENVS, () => { }); }); -describeWithFlags('qr', ALL_ENVS, () => { - it('1x1', () => { - const x = tensor2d([[10]], [1, 1]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose(q, tensor2d([[-1]], [1, 1])); - expectArraysClose(r, tensor2d([[-10]], [1, 1])); - }); +describeWithFlags('lu', CPU_ENVS, () => { + const la = tf.linalg; - it('2x2', () => { - const x = tensor2d([[1, 3], [-2, -4]], [2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, tensor2d([[-0.4472, -0.8944], [0.8944, -0.4472]], [2, 2])); - expectArraysClose(r, tensor2d([[-2.2361, -4.9193], [0, -0.8944]], [2, 2])); - }); + const testWith = (a: Tensor) => { - it('2x2x2', () => { - const x = tensor3d([[[-1, -3], [2, 4]], [[1, 3], [-2, -4]]], [2, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor3d( - [ - [[-0.4472, -0.8944], [0.8944, -0.4472]], - [[-0.4472, -0.8944], [0.8944, -0.4472]] - ], - [2, 2, 2])); - expectArraysClose( - r, - tensor3d( - [ - [[2.2361, 4.9193], [0, 0.8944]], - [[-2.2361, -4.9193], [0, -0.8944]] - ], - [2, 2, 2])); - }); - - it('2x1x2x2', () => { - const x = - tensor4d([[[[-1, -3], [2, 4]]], [[[1, 3], [-2, -4]]]], [2, 1, 2, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor4d( - [ - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - [[[-0.4472, -0.8944], [0.8944, -0.4472]]], - ], - [2, 1, 2, 2])); - expectArraysClose( - r, - tensor4d( - [ - [[[2.2361, 4.9193], [0, 0.8944]]], - [[[-2.2361, -4.9193], [0, -0.8944]]] - ], - [2, 1, 2, 2])); + function test_decomp( a: Tensor ) + { + let [lu,p] = la.lu(a); + let l = la.setDiag( la.bandPart(lu, -1, 0), tf.ones(lu.shape.slice(0,-1) ) ); + let u = la.bandPart(lu, 0,-1); + let A = tf.matMul(l,u); + + expectArraysEqual(lu.shape, a.shape ); + expectArraysEqual( p.shape, a.shape.slice(0,-1) ); + + const stride = a.shape[a.rank-2]; + + A = A.reshape( [-1].concat( A.shape.slice(-1) ) ); + a = a.reshape( [-1].concat( a.shape.slice(-1) ) ); + + p = p.reshape( [-1].concat( p.shape.slice(-1) ) ); + p = p.add( + tf.range(0,p.shape[0]*stride,stride,'int32').reshape([-1,1]) + ); + + a = a.gather( p.flatten() ); + + expectArraysClose(A,a); + }; + test_decomp(a); + + // TEST GRADIENTS + const w = tf.randomUniform(a.shape,-1,+1), + f = (a: Tensor) => la.lu(a)[0].mul(w).mean() as Scalar, + g = numDiff(f), + h = tf.grad(f); + + expectArraysClose( g(a), h(a) ); + }; + + it('2x2', () => testWith( tf.tensor2d([[1,2], + [3,4]], [2,2]) ) ); + + for( let run=32; run-- > 0; ) + { + let A_shape = Array.from({ length: randInt(2,5) }, () => randInt(1,5) ); + A_shape[A_shape.length-1] = A_shape[A_shape.length-2]; + + it(`random#${run}_${A_shape.join('x')}`, () => { + const ONE = tf.scalar(1), + 
TWO = tf.scalar(2); + // create a random matrix starting from a random singular value decomposition + // (this way we control condition number) + const [q1] = la.qr( tf.randomUniform(A_shape,-1,+1) ), + [q2] = la.qr( tf.randomUniform(A_shape,-1,+1) ); + + const sign = tf.randomUniform(A_shape.slice(0,-1),0,2,'int32').cast('float32').mul(TWO).sub(ONE), + magn = tf.randomNormal (A_shape.slice(0,-1),/*mean=*/1,/*stdDev=*/0.1); + + const s = la.setDiag( tf.zeros(A_shape), tf.mul(sign,magn) ); + + const a = [q1,s,q2].reduce( (a,b) => tf.matMul(a,b) ); + testWith(a); + }); + } +}); + +describeWithFlags('luSolve', CPU_ENVS, () => { + const la = tf.linalg; + + const testWith = (a: Tensor, y: Tensor) => { + let [lu,p] = la.lu(a); + + let x = la.luSolve(lu,p,y); + + const Y = tf.matMul(a,x); + + expectArraysClose(Y, tf.broadcastTo(y,Y.shape) ); + }; + + it('2x2', () => testWith( tf.tensor2d([[1,2], + [3,4]]), tf.tensor2d([[5, 6, 7], + [8, 9,10]]) ) ); + + for( let run=32; run-- > 0; ) + { + let A_shape = Array.from({ length: randInt(2,5) }, () => randInt(1,5) ), + y_shape = A_shape.slice( randInt(0,A_shape.length-2) ); + if( Math.random() < 0.5 ) + [A_shape,y_shape] = [y_shape,A_shape]; + A_shape[A_shape.length-1] = A_shape[A_shape.length-2]; + + for( let L=A_shape.length-2, y=y_shape.length-2; L-- > 0 && y-- > 0; ) + switch( randInt(0,3) ) + { + case 0: break; + case 1: A_shape[L] = 1; break; + case 2: y_shape[y] = 1; break; + } + + it(`random#${run}_${A_shape.join('x')}`, () => { + const ONE = tf.scalar(1); + const TWO = tf.scalar(2); + // create a random matrix starting from a random singular value decomposition + // (this way we control condition number) + const [q1] = la.qr( tf.randomUniform(A_shape,-1,+1) ); + const [q2] = la.qr( tf.randomUniform(A_shape,-1,+1) ); + const sign = tf.randomUniform(A_shape.slice(0,-1),0,2,'int32').cast('float32').mul(TWO).sub(ONE); + const magn = tf.randomNormal (A_shape.slice(0,-1),/*mean=*/1,/*stdDev=*/0.2); + + const s = la.setDiag( tf.zeros(A_shape), tf.mul(sign,magn) ); + + const a = [q1,s,q2].reduce( (a,b) => tf.matMul(a,b) ); + const y = tf.randomUniform(y_shape, -1, +1); + testWith(a,y); + }); + } +}); + +describeWithFlags('adjoint', ALL_ENVS, () => { + it('2x3', () => { + const a = tf.tensor2d([[1,2,3], + [4,5,6]], [2,3]), + a_T = tf.tensor2d([[1,4], + [2,5], + [3,6]],[3,2]); + expectArraysEqual( tf.linalg.adjoint(a), a_T ); + }); + it('3x2x1', () => { + const a = tf.tensor3d([[[1],[2]], + [[3],[4]], + [[5],[6]]], [3,2,1]), + a_T = tf.tensor3d([[[1,2]], + [[3,4]], + [[5,6]]], [3,1,2]); + expectArraysEqual( tf.linalg.adjoint(a), a_T ); }); +}); + + +describeWithFlags('diagPart', ALL_ENVS, () => { + const la = tf.linalg; + function testWith( a: Tensor, d: Tensor ): void + { + for( a of [la.adjoint(a),a] ) + { + expectArraysEqual(la.diagPart(a), d); + + const w = tf.randomUniform(d.shape,-1,+1), + f = (a: Tensor) => la.diagPart(a).mul(w).mean() as Scalar, + g = numDiff(f), + h = tf.grad(f); + expectArraysClose( g(a), h(a) ); + } + } + + it('3x4', () => { + let a = tf.tensor2d([ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12] + ]); + const d = tf.tensor1d([1,6,11]); + testWith(a,d); + }); it('3x3', () => { - const x = tensor2d([[1, 3, 2], [-2, 0, 7], [8, -9, 4]], [3, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d( - [ - [-0.1204, 0.8729, 0.4729], [0.2408, -0.4364, 0.8669], - [-0.9631, -0.2182, 0.1576] - ], - [3, 3])); - expectArraysClose( - r, - tensor2d( - [[-8.3066, 8.3066, -2.4077], [0, 4.5826, -2.1822], [0, 0, 7.6447]], - [3, 
3])); + let a = tf.tensor2d([ + [1, 2, 3], + [5, 6, 7], + [9,10,11] + ]); + const d = tf.tensor1d([1,6,11]); + testWith(a,d); }); - - it('3x2, fullMatrices = default false', () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d( - [[-0.2673, 0.9221], [-0.8018, -0.3738], [0.5345, -0.0997]], - [3, 2])); - expectArraysClose(r, tensor2d([[-3.7417, 2.4054], [0, 2.8661]], [2, 2])); + it('2x2x3', () => { + let a = tf.tensor3d( + [[[ 1, 2, 3], + [ 4, 5, 6]], + [[ 7, 8, 9], + [10,11,12]]] + ); + const d = tf.tensor2d([[1,5],[7,11]]); + testWith(a,d); }); +}) - it('3x2, fullMatrices = true', () => { - const x = tensor2d([[1, 2], [3, -3], [-2, 1]], [3, 2]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose( - q, - tensor2d( - [ - [-0.2673, 0.9221, 0.2798], [-0.8018, -0.3738, 0.4663], - [0.5345, -0.0997, 0.8393] - ], - [3, 3])); - expectArraysClose( - r, tensor2d([[-3.7417, 2.4054], [0, 2.8661], [0, 0]], [3, 2])); - }); - it('2x3, fullMatrices = default false', () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]); - const [q, r] = tf.linalg.qr(x); - expectArraysClose( - q, - tensor2d([[-0.3162278, -0.9486833], [0.9486833, -0.31622773]], [2, 2])); - expectArraysClose( - r, - tensor2d( - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]], [2, 3]), +describeWithFlags('setDiag', ALL_ENVS, () => { + const la = tf.linalg; + + function testWith( a: Tensor, d: Tensor, b: Tensor ): void + { + for( let i=2; i-- > 0; ) + { + expectArraysEqual(la.setDiag(a,d), b); + + const w = tf.randomUniform(a.shape,-1,+1), + f = (a: Tensor, d: Tensor) => la.setDiag(a,d).mul(w).mean() as Scalar, + g1 = numDiff( a => f(a,d) )(a), + g2 = numDiff( d => f(a,d) )(d), + [h1,h2] = tf.grads(f)([a,d]); + expectArraysClose(g1,h1); + expectArraysClose(g2,h2); + + a = la.adjoint(a); + b = la.adjoint(b); + } + } + + it('3x4', () => { + const a = tf.tensor2d([ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12] + ]); + const d = tf.tensor1d([13,14,15]); + const b = tf.tensor2d([ + [13, 2, 3, 4], + [5, 14, 7, 8], + [9, 10,15,12] + ]); + testWith(a,d,b); + }); + it('3x3', () => { + const a = tf.tensor2d([ + [1, 2, 3], + [5, 6, 7], + [9,10,11] + ]); + const d = tf.tensor1d([12,13,14]); + const b = tf.tensor2d([ + [12, 2, 3], + [5, 13, 7], + [9, 10,14] + ]); + testWith(a,d,b); + }); + it('2x2x3', () => { + const a = tf.tensor3d( + [[[ 1, 2, 3], + [ 4, 5, 6]], + [[ 7, 8, 9], + [10,11,12]]] ); + const d = tf.tensor2d([[13,14], + [15,16]]); + const b = tf.tensor3d( + [[[13, 2, 3], + [ 4,14, 6]], + [[15, 8, 9], + [10,16,12]]] + ); + testWith(a,d,b); }); +}) - it('2x3, fullMatrices = true', () => { - const x = tensor2d([[1, 2, 3], [-3, -2, 1]], [2, 3]); - const [q, r] = tf.linalg.qr(x, true); - expectArraysClose( - q, - tensor2d([[-0.3162278, -0.9486833], [0.9486833, -0.31622773]], [2, 2])); - expectArraysClose( - r, - tensor2d( - [[-3.162, -2.5298, -2.3842e-07], [0, -1.2649, -3.162]], [2, 3]), - ); + +describeWithFlags('bandPart', ALL_ENVS, () => { + it('3x4', () => { + const a = tf.tensor2d([ + [1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12] + ]); + expectArraysEqual( tf.linalg.bandPart(a,0,0), tf.tensor2d([[1,0, 0, 0], + [0,6, 0, 0], + [0,0,11, 0]]) ); + expectArraysEqual( tf.linalg.bandPart(a,0,1), tf.tensor2d([[1,2, 0, 0], + [0,6, 7, 0], + [0,0,11,12]]) ); + expectArraysEqual( tf.linalg.bandPart(a,0,2), tf.tensor2d([[1,2, 3, 0], + [0,6, 7, 8], + [0,0,11,12]]) ); + expectArraysEqual( tf.linalg.bandPart(a,0,2), tf.tensor2d([[1,2, 3, 0], + [0,6, 7, 
8], + [0,0,11,12]]) ); + for( const numUpper of [3,4,-1,-2] ) + expectArraysEqual( tf.linalg.bandPart(a,0,numUpper), tf.tensor2d([[1,2, 3, 4], + [0,6, 7, 8], + [0,0,11,12]]) ); + + + expectArraysEqual( tf.linalg.bandPart(a,1,0), tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [0,10,11, 0]]) ); + expectArraysEqual( tf.linalg.bandPart(a,1,1), tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [0,10,11,12]]) ); + expectArraysEqual( tf.linalg.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) ); + expectArraysEqual( tf.linalg.bandPart(a,1,2), tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [0,10,11,12]]) ); + for( const numUpper of [3,4,-1,-2] ) + expectArraysEqual( tf.linalg.bandPart(a,1,numUpper), tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [0,10,11,12]]) ); + + + for( const numLower of [2,3,-1,-2]) + { + expectArraysEqual( tf.linalg.bandPart(a,numLower,0), tf.tensor2d([[1, 0, 0, 0], + [5, 6, 0, 0], + [9,10,11, 0]]) ); + expectArraysEqual( tf.linalg.bandPart(a,numLower,1), tf.tensor2d([[1, 2, 0, 0], + [5, 6, 7, 0], + [9,10,11,12]]) ); + expectArraysEqual( tf.linalg.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) ); + expectArraysEqual( tf.linalg.bandPart(a,numLower,2), tf.tensor2d([[1, 2, 3, 0], + [5, 6, 7, 8], + [9,10,11,12]]) ); + for( const numUpper of [3,4,-1,-2] ) + expectArraysEqual( tf.linalg.bandPart(a,numLower,numUpper), tf.tensor2d([[1, 2, 3, 4], + [5, 6, 7, 8], + [9,10,11,12]]) ); + } + + for( const numUpper of [0,1,2,3,4,-1,-2] ) + for( const numLower of [0,1,2,3, -1,-2] ) { + const w = tf.randomUniform(a.shape), + f = (x: Tensor) => tf.linalg.bandPart(x,numLower,numUpper).mul(w).mean() as Scalar, + g = numDiff(f), + h = tf.grad(f); + expectArraysClose( g(a), h(a) ); + }; + }); +}); + + +describeWithFlags('cholesky', ALL_ENVS, () => { + const la = tf.linalg; + + function testWith( l: Tensor ) { + const a = tf.matMul(l,l,false,true); + const L = la.cholesky(a); + expectArraysClose(L,l); + + const w = tf.randomUniform(a.shape,-1,+1), + f = (a: Tensor) => la.cholesky(a).mul(w).mean() as Scalar, + g = numDiff(f), + h = tf.grad(f); + + expectArraysClose( h(a), g(a) ); + }; + + for( let run=128; run-- > 0; ) + { + let L_shape = Array.from({ length: randInt(2,5) }, () => randInt(1,5) ), + d_shape = L_shape.slice(0,-1); + L_shape[L_shape.length-1] = L_shape[L_shape.length-2]; + + // RUN TEST + it(`random#${run}_${L_shape.join('x')}`, () => { + let l = tf.randomUniform(L_shape,-0.2,+0.2) + l = la.bandPart(l,-1, 0); + l = la.setDiag( l, tf.randomNormal(d_shape,/*mean=*/1,/*stdDev=*/0.2) ); + testWith(l); + }); + } +}) + + +describeWithFlags('triangularSolve', ALL_ENVS, () => { + const testWith = (L: Tensor, y: Tensor) => { + const test = (adjoint: boolean) => + { + let tril = tf.linalg.bandPart(L,-1, 0), + triu = tf.linalg.bandPart(L, 0,-1); + if( adjoint ) { + tril = tf.linalg.adjoint(tril); + triu = tf.linalg.adjoint(triu); + } + for( const lower of [true,undefined] ) + { + const x = tf.linalg.triangularSolve(L,y, lower, adjoint); + const [a,b] = broadcastMatrices( y, tril.matMul(x) ); + expectArraysClose(a,b); + } + const x = tf.linalg.triangularSolve(L,y, /*lower=*/false, adjoint); + const [a,b] = broadcastMatrices( y, triu.matMul(x) ); + expectArraysClose(a,b); + + for( const lower of [false,true,undefined] ) + { + const w = tf.randomUniform(y.shape,-1,+1), + f = (L: Tensor, y: Tensor) => { + return tf.linalg.triangularSolve(L,y,lower).mul(w).mean() as Scalar + }, + [g1,g2] = tf.grads(f)([L,y]), + h1 = numDiff( (L: Tensor) => f(L,y) 
)(L), + h2 = numDiff( (y: Tensor) => f(L,y) )(y); + expectArraysClose(g1,h1); + expectArraysClose(g2,h2); + } + } + test(undefined); + test(false); + test(true); + }; + + it('3x3', () => testWith( + tf.tensor2d( + [[1,2,3], + [4,5,6], + [7,8,9]] + ), + tf.tensor2d( + [[10,11], + [12,13], + [14,15]] + ) + )); + + for( let run=0; run < 16; run++ ) + { // RANDOMLY GENERATE BROADCAST-COMPATIBLE SHAPES + let L_shape = Array.from({ length: randInt(2,5) }, () => randInt(1,4) ), + y_shape = L_shape.slice( randInt(0,L_shape.length-2) ); + + if( Math.random() < 0.5 ) + { + y_shape = Array.from({ length: randInt(2,5) }, () => randInt(1,4) ); + L_shape = y_shape.slice( randInt(0,y_shape.length-2) ); + } + L_shape[L_shape.length-1] = L_shape[L_shape.length-2]; + + for( let L=L_shape.length-2, y=y_shape.length-2; L-- > 0 && y-- > 0; ) + switch( randInt(0,3) ) + { + case 0: break; + case 1: L_shape[L] = 1; break; + case 2: y_shape[y] = 1; break; + } + + // RUN TEST + it(`random#${run}_${L_shape.join('x')}_${y_shape.join('x')}`, () => { + const ONE = tf.scalar(1), + TWO = tf.scalar(2); + const y = tf.randomUniform(y_shape,-1,+1); + let L: Tensor = tf.randomUniform(L_shape,-1,+1); + // SET THE DIAGONAL TO BE FAR FROM ZERO + const i = tf.range(0,L_shape[L_shape.length-2]).reshape([-1,1]), + j = tf.range(0,L_shape[L_shape.length-1]), + diag = tf.equal(i,j).cast('float32'), + sign = tf.randomUniform(L_shape,0,2,'int32').cast('float32').mul(TWO).sub(ONE), + magn = tf.randomNormal (L_shape, /*mean=*/1,/*stdDev=*/0.1); + L = tf.add( + diag.sub(ONE).mul(L), // <- off-diagonal + diag.mul(sign).mul(magn) // <- diagonal + ); + L = tf.clone(L); + testWith(L,y); + }); + } +}); + + +describeWithFlags('qr', ALL_ENVS, () => { + const testWith = (a: Tensor) => { + const [m,n] = a.shape.slice(-2), + l = Math.min(m,n), + T = Array.from({ length: a.rank }, (_,i) => i ); + T[T.length-2] = T.length-1; + T[T.length-1] = T.length-2; + + for( const fullMatrices of [undefined,false,true] ) + { + const tril = function(){ + const [p,q] = fullMatrices ? [m,n] : [l,n], + i = tf.range(0,p).reshape([p,1]), + j = tf.range(0,q).reshape([1,q]); + return i.greater(j).cast('float32'); + }(); + const EYE = function(){ + const d = fullMatrices ? m : l; + return tf.stack( + Array.from( + { length: a.shape.slice(0,-2).reduce( (x,y) => x*y, 1 ) }, + () => tf.eye(d) + ) + ).reshape([...a.shape.slice(0,-2),d,d]); + }(); + let [q,r] = tf.linalg.qr(a,fullMatrices); + const q_T = q.transpose(T); + + // TEST SHAPE OF Q + expectArraysEqual( q.shape.slice(0,-1), a.shape.slice(0,-1) ); + expectArraysClose( q.shape.slice( -1), fullMatrices ? [m ] : [l ] ); + + // TEST SHAPE OF R + expectArraysEqual( r.shape.slice(0,-2), a.shape.slice(0,-2) ); + expectArraysClose( r.shape.slice( -2), fullMatrices ? 
[m,n] : [l,n] );
+
+      // TEST DECOMPOSITION (Q @ R == A)
+      expectArraysClose( q.matMul(r), a );
+
+      // TEST ORTHOGONALITY OF Q
+      expectArraysClose( q_T.matMul(q  ), EYE );
+      if( fullMatrices || n >= m ) expectArraysClose( q  .matMul(q_T), EYE );
+
+      // TEST TRIANGULARITY OF R
+      expectArraysEqual( tril.mul(r), tf.zeros(r.shape) );
+
+      // TEST GRADIENTS (random weights make every entry of q and r contribute)
+      const wQ = tf.randomUniform(q.shape,-1,+1),
+            wR = tf.randomUniform(r.shape,-1,+1),
+            f = (a: Tensor) => {
+              const [q,r] = tf.linalg.qr(a,fullMatrices);
+              return q.mul(wQ).mean().add( r.mul(wR).mean() ) as Scalar;
+            };
+      const g = numDiff(f);
+      const h = tf.grad(f);
+      expectArraysClose( g(a), h(a) );
+    }
+  };
+
+  it('1x1', () => testWith( tensor2d([[10]], [1, 1]) ) );
+
+  it('2x2', () => testWith( tensor2d([[ 1, 3],
+                                      [-2,-4]], [2, 2]) ) );
+
+  it('2x2x2', () => testWith( tensor3d([[[-1,-3],
+                                         [ 2, 4]],
+                                        [[ 1, 3],
+                                         [-2,-4]]], [2, 2, 2]) ) );
+
+  it('2x1x2x2', () => testWith( tensor4d([[[[-1,-3],
+                                            [ 2, 4]]],
+                                          [[[ 1, 3],
+                                            [-2,-4]]]], [2, 1, 2, 2]) ) );
+
+  it('3x3', () => testWith( tensor2d([[ 1, 3, 2],
+                                      [-2, 0, 7],
+                                      [ 8,-9, 4]], [3, 3]) ) );
+
+  it('3x2', () => testWith( tensor2d([[ 1, 2],
+                                      [ 3,-3],
+                                      [-2, 1]], [3, 2]) ) );
+
+  it('2x3', () => testWith( tensor2d([[ 1, 2, 3],
+                                      [-3,-2, 1]], [2, 3]) ) );
+
+  for( let run=0; run < 16; run++ )
+  {
+    const A_shape = Array.from({ length: randInt(2,5) }, () => randInt(1,4) );
+    it(`random#${run}_${A_shape.join('x')}`, () => testWith( tf.randomUniform(A_shape,-1,+1) ));
+  }
+
+  it('Is reasonably fast', () => { // <- TODO: is there a better way to test this with a timeout?
+    const N = 128,
+          A  = tf.randomUniform([N,N],-1,+1),
+          wQ = tf.randomUniform([N,N],-1,+1),
+          wR = tf.randomUniform([N,N],-1,+1),
+          f = (a: Tensor) => {
+            const [q,r] = tf.linalg.qr(a);
+            return q.mul(wQ).mean().add( r.mul(wR).mean() );
+          };
+    const g = tf.grad(f);
+    expectArraysClose( g(A), g(A) ); // <- this hopefully prevents g(A) from being JITed/optimized away...
+  });
 
   it('Does not leak memory', () => {
-    const x = tensor2d([[1, 3], [-2, -4]], [2, 2]);
+    const x = tensor2d([[ 1, 3],
+                        [-2,-4]], [2, 2]);
     // The first call to qr creates and keeps internal singleton tensors.
     // Subsequent calls should always create exactly two tensors.
     tf.linalg.qr(x);
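The gradient assertions above all follow one pattern: reduce the op under test to a randomly weighted scalar, then compare backprop against finite differences. Distilled into a standalone sketch (using `numDiff` from test_util, added later in this patch; the op and shapes are just placeholders):

```ts
import * as tf from '@tensorflow/tfjs-core';
import {Scalar, Tensor} from '../tensor';
import {numDiff} from '../test_util';

// Random weights make every output element contribute to the scalar,
// so an error in any partial derivative shows up in the comparison.
const a = tf.randomUniform([3, 3], -1, 1);
const w = tf.randomUniform([3, 3], -1, 1);
const f = (x: Tensor) => tf.linalg.bandPart(x, 1, 1).mul(w).mean() as Scalar;

const analytic = tf.grad(f)(a);  // backpropagated gradient
const numeric  = numDiff(f)(a);  // central finite differences

analytic.sub(numeric).abs().max().print(); // ~0 within float32 noise
```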
diff --git a/src/ops/linalg_util.ts b/src/ops/linalg_util.ts
new file mode 100644
index 0000000000..74cc9005ad
--- /dev/null
+++ b/src/ops/linalg_util.ts
@@ -0,0 +1,54 @@
+/**
+ * @license
+ * Copyright 2018 Google LLC. All Rights Reserved.
+ * Licensed under the Apache License, Version 2.0 (the "License");
+ * you may not use this file except in compliance with the License.
+ * You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ * =============================================================================
+ */
+
+/**
+ * Linear algebra ops utility methods.
+ */
+
+import {Tensor} from '../tensor';
+import {broadcastTo} from './array_ops';
+
+/** Broadcasts the given tensors of matrices iff necessary. This is a matrix
+ * broadcast, i.e. the last two dimensions of every tensor are neither broadcast
+ * nor checked in any way.
+ */
+export function broadcastMatrices( ...tensors: Tensor[] ): Tensor[]
+{
+  if( tensors.length < 2 )
+    throw new Error('broadcastMatrices(): At least two tensors expected.');
+  for( const tensor of tensors )
+    if( tensor.rank < 2 ) throw new Error('broadcastMatrices(): At least one tensor has rank < 2.');
+
+  const rank : number   = tensors.map( x => x.rank ).reduce( (r,s) => Math.max(r,s) ),
+        shape: number[] = Array.from({ length: rank }, () => 1 );
+
+  // FIND COMMON (BROADCASTED) SHAPE
+  for( const tensor of tensors )
+    for( let i=rank-2, j=tensor.rank-2; i-- > 0 && j-- > 0; )
+      if( 1 === shape[i] )
+        shape[i] = tensor.shape[j];
+      else if( shape[i] != tensor.shape[j] && tensor.shape[j] != 1 )
+        throw new Error(`broadcastMatrices(): Shapes not broadcast-compatible.`);
+
+  return tensors.map( tensor => {
+    shape[shape.length-2] = tensor.shape[tensor.rank-2];
+    shape[shape.length-1] = tensor.shape[tensor.rank-1];
+    return broadcastTo( tensor, shape.slice() ); // <- make a protective copy of shape (better safe than sorry)
+  });
+}
\ No newline at end of file
diff --git a/src/ops/matmul.ts b/src/ops/matmul.ts
index 630d428242..eeb1097d51 100644
--- a/src/ops/matmul.ts
+++ b/src/ops/matmul.ts
@@ -21,6 +21,7 @@ import {convertToTensor} from '../tensor_util_env';
 import {TensorLike} from '../types';
 import * as util from '../util';
 import {op} from './operation';
+import {broadcastMatrices} from './linalg_util';
 
 /**
  * Computes the dot product of two matrices, A * B. These must be matrices.
@@ -40,8 +41,10 @@ import {op} from './operation';
 function matMul_<T extends Tensor>(
     a: T|TensorLike, b: T|TensorLike, transposeA = false,
     transposeB = false): T {
-  const $a = convertToTensor(a, 'a', 'matMul');
-  const $b = convertToTensor(b, 'b', 'matMul');
+  const [$a,$b] = broadcastMatrices(
+    convertToTensor(a, 'a', 'matMul'),
+    convertToTensor(b, 'b', 'matMul')
+  );
 
   const innerShapeA =
       transposeA ? $a.shape[$a.rank - 2] : $a.shape[$a.rank - 1];
diff --git a/src/test_util.ts b/src/test_util.ts
index 44cd997c35..aaebd2d30e 100644
--- a/src/test_util.ts
+++ b/src/test_util.ts
@@ -17,10 +17,53 @@
 import {ENV} from './environment';
 import {Features} from './environment_util';
-import {Tensor} from './tensor';
+import {Scalar, Tensor} from './tensor';
 import {TypedArray} from './types';
 import * as util from './util';
 
+/** Computes the gradients using finite differences.
+ *
+ * SEE: https://en.wikipedia.org/wiki/Finite_difference#Forward,_backward,_and_central_differences
+ *
+ * FIXME: this is terribly imprecise... wish there was double precision support *hint hint*.
+ */ +export const numDiff = (f: (_: Tensor) => Scalar) => (a: Tensor) => { + if( a.dtype !== 'float32' ) + throw new Error(`numDiff(): dtype=${a.dtype} not supported.`); + + const dA_shape = a.shape, + A = Float32Array.from( a.dataSync() ), + d = 2**-10; + + function val( scalar: Tensor ): number + { + if( scalar.rank !== 0 ) + throw new Error('f() returned a non-scalar value.'); + return scalar.dataSync()[0]; + } + + return ENV.engine.tidy(() => { + a = Tensor.make(dA_shape,{values: A}); + + const dA = new Float32Array( A.length ); + + for( let i=0; i < A.length; i++ ) + { // use central difference + const A_i = A[i], + A_hi = A_i + d, + A_lo = A_i - d; + + // DISPOSAL (HOPEFULLY) REMOVES DATA FROM GPU AND FORCES REUPLOAD + A[i] = A_lo; a.dispose(); a = Tensor.make(dA_shape,{values: A}); const F_lo = val(f(a)); + A[i] = A_hi; a.dispose(); a = Tensor.make(dA_shape,{values: A}); const F_hi = val(f(a)); + dA[i] = (F_hi - F_lo) / (A_hi - A_lo); + A[i] = A_i; + } + + return Tensor.make(dA_shape,{values: dA}); + }); +}; + // TODO(smilkov): Move these constants to jasmine_util. export const WEBGL_ENVS: Features = { 'HAS_WEBGL': true From d29b47a91fd60b627238210722bf801d6556de01 Mon Sep 17 00:00:00 2001 From: Dirk Toewe Date: Sat, 27 Oct 2018 19:44:10 +0200 Subject: [PATCH 2/2] Fixed LU pivoting & testing --- src/ops/linalg_ops.ts | 12 ++++---- src/ops/linalg_ops_test.ts | 56 ++++++++++++++++++++++++++------------ src/test_util.ts | 2 +- 3 files changed, 45 insertions(+), 25 deletions(-) diff --git a/src/ops/linalg_ops.ts b/src/ops/linalg_ops.ts index d31e5f8954..4953767462 100644 --- a/src/ops/linalg_ops.ts +++ b/src/ops/linalg_ops.ts @@ -145,9 +145,10 @@ function lu_p_decomp_( a: Tensor ): [Tensor,Tensor] { let p=i; for( let j=i+1; j < N; j++ ) - if( Math.abs( LU[LU_off+P[P_off+j]*N+i] ) - > Math.abs( LU[LU_off+P[P_off+p]*N+i] ) ) + if( Math.abs( LU[LU_off + N*j+i] ) + > Math.abs( LU[LU_off + N*p+i] ) ) p=j; + if( i != p ) { const P_p = P[P_off+i]; P[P_off+i] = P[P_off+p]; @@ -1078,7 +1079,7 @@ function qr_( a: Tensor, fullMatrices = false ): [Tensor, Tensor] { * * @returns b Tensor of shape [...,M,N], such that: `b[...,p[...,i],j] == a[...,i,j]`. */ -function permuteRowsInv_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this? +export function permuteRowsInv_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this? { if( a.rank != p.rank+1 ) throw new Error(`permuteRows(): a.rank and p.rank do not match.`); for( let i=p.rank; i-- > 0; ) @@ -1090,9 +1091,8 @@ function permuteRowsInv_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this P_inv = new Int32Array(P.length); for( let off=0; off < P.length; off += N ) - for( let i =0; i < N; i++ ) { + for( let i =0; i < N ; i++ ) P_inv[off + P[off+i]] = off+i; - } return gather( a.reshape( [-1].concat( a.shape.slice(-1) ) ), @@ -1107,7 +1107,7 @@ function permuteRowsInv_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this * * @returns b Tensor of shape [...,M,N], such that: `b[...,i,j] == a[...,p[...,i],j]`. */ -function permuteRows_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this? +export function permuteRows_( a: Tensor, p: Tensor ): Tensor // <- TODO: export this? 
{
   if( a.rank != p.rank+1 ) throw new Error(`permuteRows(): a.rank and p.rank do not match.`);
   for( let i=p.rank; i-- > 0; )
diff --git a/src/ops/linalg_ops_test.ts b/src/ops/linalg_ops_test.ts
index 0e31405840..97d4468e36 100644
--- a/src/ops/linalg_ops_test.ts
+++ b/src/ops/linalg_ops_test.ts
@@ -20,6 +20,7 @@ import {describeWithFlags} from '../jasmine_util';
 import {Scalar, Tensor, Tensor1D, Tensor2D} from '../tensor';
 import {broadcastMatrices} from './linalg_util';
 import {CPU_ENVS, ALL_ENVS, expectArraysClose, expectArraysEqual, WEBGL_ENVS, numDiff} from '../test_util';
+import {permuteRows_, permuteRowsInv_} from './linalg_ops';
 import {scalar, tensor1d, tensor2d, tensor3d, tensor4d} from './ops';
 
@@ -109,32 +110,48 @@ describeWithFlags('lu', CPU_ENVS, () => {
     let u = la.bandPart(lu, 0,-1);
     let A = tf.matMul(l,u);
 
+    expectArraysEqual( l.abs().max(), tf.scalar(1) );
+
     expectArraysEqual(lu.shape, a.shape );
     expectArraysEqual( p.shape, a.shape.slice(0,-1) );
 
-    const stride = a.shape[a.rank-2];
-
-    A = A.reshape( [-1].concat( A.shape.slice(-1) ) );
-    a = a.reshape( [-1].concat( a.shape.slice(-1) ) );
-
-    p = p.reshape( [-1].concat( p.shape.slice(-1) ) );
-    p = p.add(
-      tf.range(0,p.shape[0]*stride,stride,'int32').reshape([-1,1])
-    );
-
-    a = a.gather( p.flatten() );
-
-    expectArraysClose(A,a);
+    expectArraysClose( A, permuteRows_(a,p) );
+    expectArraysClose( permuteRowsInv_(A,p), a );
   };
   test_decomp(a);
 
   // TEST GRADIENTS
+  const perm: Int32Array[] = [];
   const w = tf.randomUniform(a.shape,-1,+1),
-        f = (a: Tensor) => la.lu(a)[0].mul(w).mean() as Scalar,
+        f = (a: Tensor) => {
+          const [lu,p] = la.lu(a);
+          perm.push( p.dataSync() as Int32Array );
+          return lu.mul(w).mean() as Scalar;
+        },
         g = numDiff(f),
         h = tf.grad(f);
 
-  expectArraysClose( g(a), h(a) );
+  try {
+    const g_a = g(a);
+    const noChangeInPerm = perm.slice(1).every(
+      p => perm[0].every( (q,i) => q == p[i] )
+    );
+    // If the permutation changes during the finite-difference run, the gradients may change drastically.
+    if( noChangeInPerm )
+      expectArraysClose( g_a, h(a) );
+    else
+      console.log(
+        "lu(): p changed during finite difference calculation. Skipping test." +
+        "\nIt's perfectly normal for this to happen once or twice."
+ ); + } + catch(err) { + console.log('A' ); a .print(); + console.log('LU'); la.lu(a)[0].print(); + console.log('G' ); g(a).print(); + console.log('H' ); h(a).print(); + throw err; + } }; it('2x2', () => testWith( tf.tensor2d([[1,2], @@ -153,8 +170,11 @@ describeWithFlags('lu', CPU_ENVS, () => { const [q1] = la.qr( tf.randomUniform(A_shape,-1,+1) ), [q2] = la.qr( tf.randomUniform(A_shape,-1,+1) ); - const sign = tf.randomUniform(A_shape.slice(0,-1),0,2,'int32').cast('float32').mul(TWO).sub(ONE), - magn = tf.randomNormal (A_shape.slice(0,-1),/*mean=*/1,/*stdDev=*/0.1); + const magn = tf.randomNormal (A_shape.slice(0,-1),/*mean=*/1,/*stdDev=*/0.2), + sign = tf.randomUniform(A_shape.slice(0,-1),0,2,'int32') + .cast('float32') + .mul(TWO) + .sub(ONE); const s = la.setDiag( tf.zeros(A_shape), tf.mul(sign,magn) ); @@ -205,7 +225,7 @@ describeWithFlags('luSolve', CPU_ENVS, () => { const [q1] = la.qr( tf.randomUniform(A_shape,-1,+1) ); const [q2] = la.qr( tf.randomUniform(A_shape,-1,+1) ); const sign = tf.randomUniform(A_shape.slice(0,-1),0,2,'int32').cast('float32').mul(TWO).sub(ONE); - const magn = tf.randomNormal (A_shape.slice(0,-1),/*mean=*/1,/*stdDev=*/0.2); + const magn = tf.randomNormal (A_shape.slice(0,-1),/*mean=*/1,/*stdDev=*/0.1); const s = la.setDiag( tf.zeros(A_shape), tf.mul(sign,magn) ); diff --git a/src/test_util.ts b/src/test_util.ts index aaebd2d30e..63e444a951 100644 --- a/src/test_util.ts +++ b/src/test_util.ts @@ -33,7 +33,7 @@ export const numDiff = (f: (_: Tensor) => Scalar) => (a: Tensor) => { const dA_shape = a.shape, A = Float32Array.from( a.dataSync() ), - d = 2**-10; + d = 2**-11; function val( scalar: Tensor ): number {
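To round off: with the matmul.ts change above, `matMul` now broadcasts leading batch dimensions through `broadcastMatrices`. A minimal sketch of the new behavior (shapes chosen for illustration):

```ts
import * as tf from '@tensorflow/tfjs-core';

// Leading dims broadcast ([2,1] vs [5] -> [2,5]);
// the trailing matrix dims must still be compatible (4 x 4).
const a = tf.randomNormal([2, 1, 3, 4]);
const b = tf.randomNormal([5, 4, 2]);

const c = tf.matMul(a, b);
console.log(c.shape); // [2, 5, 3, 2]
```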