import { getCoordsDataType } from './shader_compiler';
/**
 * Program description whose `userCode` is shader source text for an
 * N-dimensional gather.
 *
 * The generated `main()`:
 *   1. reads the output coordinates via `getOutputCoords()`;
 *   2. for each of the first `sliceDim` index components `j`, reads
 *      `getIndices(coords[0], j)`, rounds it, marks the lookup as out of
 *      bounds when it is negative or `>= paramsShape[j]`, and accumulates
 *      `index * strides[j]` into `flattenIndex`;
 *   3. writes `0.0` when any component was out of bounds, otherwise
 *      `getX(flattenIndex, coords[1])`.
 *
 * The per-component statements are unrolled while the string is built, so
 * `sliceDim`, `strides[j]` and `paramsShape[j]` end up as literals in the
 * generated source.
 */
export class GatherNDProgram {
  /**
   * @param {number} sliceDim Number of index components unrolled into the
   *     shader body.
   * @param {number[]} strides Multiplier applied to index component `j` when
   *     flattening.
   * @param {number[]} shape Output shape; its length selects the coordinate
   *     type through `getCoordsDataType`. The shader reads `coords[0]` and
   *     `coords[1]`, so this is presumably rank 2 - confirm against callers.
   * @param {number[]} paramsShape Exclusive upper bound for index component
   *     `j`.
   */
  constructor(sliceDim, strides, shape, paramsShape) {
    this.sliceDim = sliceDim;
    this.strides = strides;
    this.paramsShape = paramsShape;
    this.variableNames = ['x', 'indices'];
    this.outputShape = shape;

    const coordsType = getCoordsDataType(shape.length);

    // The leading empty entry reproduces the newline that precedes the
    // `int index;` declaration in the emitted text.
    const gatherLines = ['', '    int index;'];
    for (let axis = 0; axis < sliceDim; axis += 1) {
      gatherLines.push(
        `          index = round(getIndices(coords[0], ${axis}));`,
        '          out_of_bounds = out_of_bounds || index < 0;',
        `          out_of_bounds = out_of_bounds || index >= ${paramsShape[axis]};`,
        `          flattenIndex += index * ${strides[axis]};`,
      );
    }
    const gatherSteps = gatherLines.join('\n');

    // Whitespace inside these literals is part of the emitted shader text and
    // is kept exactly as-is.
    this.userCode = [
      '',
      '         void main() {',
      `          ${coordsType} coords = getOutputCoords();`,
      '          int flattenIndex = 0;',
      '          bool out_of_bounds = false;',
      '',
      `          ${gatherSteps}`,
      '',
      '          setOutput(out_of_bounds ? 0.0 : getX(flattenIndex, coords[1]));',
      '        }',
      '      ',
    ].join('\n');
  }
}
